# app_utils.py
# This file will contain the refactored core logic for training and prediction.
import os
import time
import pickle
import torch
import gradio as gr
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from model.config import load_config
from model.genconvit_ed import GenConViTED
from model.genconvit_vae import GenConViTVAE
from dataset.loader import load_data, load_checkpoint
from model.pred_func import set_result, load_genconvit, df_face, pred_vid, real_or_fake
# Load configuration
# Module-level singletons shared by the training and prediction helpers below.
config = load_config()
# Prefer GPU when available; models and loss criteria are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_available_weights(weight_dir="weight"):
    """Return the ``.pth`` checkpoint filenames found in *weight_dir*.

    The directory is created if it does not exist yet.  When no checkpoints
    are present, returns the sentinel list ``["No weights found"]`` that the
    UI dropdowns display (and that the other helpers in this file test for).
    """
    # exist_ok avoids the exists()/makedirs() race of the original check.
    os.makedirs(weight_dir, exist_ok=True)
    weights = [f for f in os.listdir(weight_dir) if f.endswith(".pth")]
    return weights if weights else ["No weights found"]
def _count_regular_files(path):
    """Count the regular files directly inside *path*; 0 if it isn't a directory."""
    # isdir() is False for a missing path, so no separate exists() check is needed.
    if not os.path.isdir(path):
        return 0
    return sum(
        1 for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))
    )


def count_files_in_subdirs(directory):
    """Count files in the 'real' and 'fake' subdirectories of *directory*.

    Returns a display string of the form ``"Real: <n>, Fake: <m>"`` for the
    dataset-count widgets in the UI.  Missing subdirectories count as 0.
    """
    real_count = _count_regular_files(os.path.join(directory, 'real'))
    fake_count = _count_regular_files(os.path.join(directory, 'fake'))
    return f"Real: {real_count}, Fake: {fake_count}"
def get_dataset_counts():
    """Return the file-count summaries for the train, valid, and test splits.

    Each element is the display string produced by ``count_files_in_subdirs``
    for the corresponding directory in the current working directory.
    """
    splits = ('train', 'valid', 'test')
    return tuple(count_files_in_subdirs(split) for split in splits)
def train_model_gradio(model_variant, ed_pretrained, vae_pretrained, epochs, batch_size, run_test, use_fp16, progress=gr.Progress()):
    """Train the selected GenConViT variant(s), streaming status messages.

    Generator for the Gradio UI: every ``yield`` is a status string shown in
    the log component, while ``progress`` drives the progress bar.

    Args:
        model_variant: "AE", "VAE", or "AE & VAE" — which sub-models to train.
        ed_pretrained / vae_pretrained: checkpoint filenames under ./weight,
            or the sentinel "No weights found" to start from scratch.
        epochs, batch_size: training hyperparameters (coerced with int()).
        run_test: if truthy, a post-training test pass would run — currently
            a placeholder (see the comment near the end).
        use_fp16: accepted for UI symmetry but not referenced in this function.
    """
    dir_path = './'
    # Datasets are expected at ./train and ./valid, each containing
    # 'real' and 'fake' subdirectories (see count_files_in_subdirs).
    if not (os.path.exists('train') and os.path.exists('valid')):
        yield "Error: 'train' and 'valid' directories not found. Please create them and populate them with 'real' and 'fake' subdirectories."
        return
    yield "Loading data..."
    progress(0, desc="Loading data...")
    try:
        dataloaders, dataset_sizes = load_data(dir_path, int(batch_size))
        yield "Data loaded."
    except Exception as e:
        yield f"Error loading data: {e}. Please ensure the dataset is structured correctly."
        return
    # Each entry is (tag, model, optimizer); both sub-models are trained
    # sequentially when "AE & VAE" is selected.
    models = []
    optimizers = []  # NOTE(review): never populated or read — appears vestigial.
    if model_variant in ["AE", "AE & VAE"]:
        yield "Initializing AE model..."
        model_ed = GenConViTED(config)
        optimizer_ed = optim.Adam(model_ed.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if ed_pretrained and ed_pretrained != "No weights found":
            try:
                model_ed, optimizer_ed, _, _ = load_checkpoint(model_ed, optimizer_ed, filename=os.path.join("weight", ed_pretrained))
            except Exception as e:
                # Best-effort: report the failure but continue with fresh weights.
                yield f"Error loading ED checkpoint: {e}"
        models.append(("ed", model_ed, optimizer_ed))
    if model_variant in ["VAE", "AE & VAE"]:
        yield "Initializing VAE model..."
        model_vae = GenConViTVAE(config)
        optimizer_vae = optim.Adam(model_vae.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if vae_pretrained and vae_pretrained != "No weights found":
            try:
                model_vae, optimizer_vae, _, _ = load_checkpoint(model_vae, optimizer_vae, filename=os.path.join("weight", vae_pretrained))
            except Exception as e:
                yield f"Error loading VAE checkpoint: {e}"
        models.append(("vae", model_vae, optimizer_vae))
    for mod, model, optimizer in models:
        yield f"Starting training for {mod.upper()} model..."
        criterion = nn.CrossEntropyLoss().to(device)
        mse = nn.MSELoss()
        scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
        model.to(device)
        # Per-model history lists; the train/valid functions append to them.
        train_loss, train_acc, valid_loss, valid_acc = [], [], [], []
        train_func, valid_func = None, None
        # Imported lazily so the unused variant's module is never loaded.
        if mod == 'ed':
            from train.train_ed import train as train_func_ed, valid as valid_func_ed
            train_func, valid_func = train_func_ed, valid_func_ed
        else:
            from train.train_vae import train as train_func_vae, valid as valid_func_vae
            train_func, valid_func = train_func_vae, valid_func_vae
        for epoch in range(int(epochs)):
            epoch_desc = f"Epoch {epoch+1}/{int(epochs)} ({mod.upper()})"
            progress(epoch / int(epochs), desc=epoch_desc)
            yield f"{epoch_desc} - Training..."
            epoch_loss, epoch_acc = 0,0  # NOTE(review): epoch_acc is never used.
            try:
                train_loss, train_acc, epoch_loss = train_func(model, device, dataloaders["train"], criterion, optimizer, epoch, train_loss, train_acc, mse)
            except Exception as e:
                # Abort this model's epoch loop but still save whatever was trained.
                yield f"Error during training: {e}"
                break
            yield f"{epoch_desc} - Validation..."
            try:
                valid_loss, valid_acc = valid_func(model, device, dataloaders["validation"], criterion, epoch, valid_loss, valid_acc, mse)
                yield f"Epoch {epoch+1} complete for {mod.upper()}. Validation Loss: {valid_loss[-1]:.4f}, Validation Acc: {valid_acc[-1]:.4f}"
            except Exception as e:
                yield f"Error during validation: {e}"
                break
            scheduler.step()
        yield f"Training complete for {mod.upper()}. Saving model..."
        progress(1, desc=f"Saving {mod.upper()} model...")
        # Timestamped base path; .pkl holds history, .pth holds the checkpoint.
        file_path = os.path.join("weight", f'genconvit_{mod}_{time.strftime("%b_%d_%Y_%H_%M_%S", time.localtime())}')
        with open(f"{file_path}.pkl", "wb") as f:
            pickle.dump([train_loss, train_acc, valid_loss, valid_acc], f)
        # NOTE(review): if epochs == 0 the loop never ran and epoch_loss is
        # unbound here (NameError) — confirm epochs >= 1 is enforced upstream.
        state = {
            "epoch": epochs, "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(), "min_loss": epoch_loss,
        }
        weight_filename = f"{file_path}.pth"
        torch.save(state, weight_filename)
        yield f"Model saved to {weight_filename}"
        if run_test:
            yield f"Running test for {mod.upper()} model..."
            # test() function from train.py needs to be refactored to be callable here
            pass
    yield "All training processes finished."
def predict_video_gradio(video_path, ed_weight, vae_weight, num_frames, use_fp16, progress=gr.Progress()):
    """Run deepfake prediction on an uploaded video for the Gradio UI.

    Generator that yields ``(status, label, confidence)`` 3-tuples for the
    three output components; intermediate yields stream progress, the final
    yield carries the verdict.

    Bug fixed: this function is a generator (it contains ``yield``), so the
    original ``return value`` exits were swallowed — Gradio only displays
    yielded values — and the first early-exit produced a 4-tuple against
    three outputs.  Every exit path now yields a consistent 3-tuple.
    """
    if not video_path:
        yield "Please upload a video.", "", ""
        return
    # Resolve selected weights; the dropdown sentinel means "none chosen".
    net_type = None
    ed_weight_path, vae_weight_path = None, None
    if ed_weight and ed_weight != "No weights found":
        ed_weight_path = os.path.join("weight", ed_weight)
    if vae_weight and vae_weight != "No weights found":
        vae_weight_path = os.path.join("weight", vae_weight)
    # Network type depends on which combination of weights is available.
    if ed_weight_path and vae_weight_path:
        net_type = 'genconvit'
    elif ed_weight_path:
        net_type = 'ed'
    elif vae_weight_path:
        net_type = 'vae'
    else:
        yield "Status: Error", "Please select at least one model weight.", ""
        return
    yield "Status: Loading model...", "", ""
    progress(0.1, desc="Loading model...")
    try:
        model = load_genconvit(config, net_type, ed_weight_path, vae_weight_path, use_fp16)
    except Exception as e:
        yield f"Status: Error loading model", f"Details: {e}", ""
        return
    yield "Status: Model loaded. Extracting faces...", "", ""
    progress(0.3, desc="Extracting faces...")
    try:
        faces = df_face(video_path, int(num_frames))
        if len(faces) == 0:
            yield "Status: Error", "No faces detected in the video.", ""
            return
    except Exception as e:
        yield "Status: Error during face extraction", f"Details: {e}. Is dlib installed correctly?", ""
        return
    yield f"Status: {len(faces)} face(s) detected. Running prediction...", "", ""
    progress(0.8, desc="Running prediction...")
    try:
        y, y_val = pred_vid(faces, model)
        label = real_or_fake(y)
        # y_val is the confidence of the predicted class; flip for FAKE so
        # the displayed score always refers to the reported label.
        score = y_val if label == "REAL" else 1 - y_val
        confidence_str = f"{score*100:.2f}%"
        progress(1, desc="Prediction complete")
        yield f"Status: Prediction complete.", label, confidence_str
    except Exception as e:
        yield "Status: Error during prediction", f"Details: {e}", ""