Spaces:
Runtime error
Runtime error
| # app_utils.py | |
| # This file will contain the refactored core logic for training and prediction. | |
| import os | |
| import time | |
| import pickle | |
| import torch | |
| import gradio as gr | |
| from torch import nn | |
| from torch import optim | |
| from torch.optim import lr_scheduler | |
| from model.config import load_config | |
| from model.genconvit_ed import GenConViTED | |
| from model.genconvit_vae import GenConViTVAE | |
| from dataset.loader import load_data, load_checkpoint | |
| from model.pred_func import set_result, load_genconvit, df_face, pred_vid, real_or_fake | |
# Load configuration once at import time; all helpers below share it.
config = load_config()
# Prefer GPU when available; every model/criterion is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_available_weights(weight_dir="weight"):
    """Return the .pth checkpoint filenames found in *weight_dir*.

    Creates the directory if it does not exist, so first-run UIs get an
    empty (but valid) folder instead of an error.

    Args:
        weight_dir: Directory scanned for ``.pth`` checkpoint files.

    Returns:
        list[str]: Sorted ``.pth`` filenames, or the sentinel list
        ``["No weights found"]`` when none exist — callers compare against
        this exact sentinel string, so it must not change.
    """
    # exist_ok avoids a race (and an exception) if the directory appears
    # between the exists() check and makedirs() — the original two-step
    # check was not atomic.
    os.makedirs(weight_dir, exist_ok=True)
    # os.listdir order is platform-dependent; sort for a stable dropdown.
    weights = sorted(f for f in os.listdir(weight_dir) if f.endswith(".pth"))
    return weights if weights else ["No weights found"]
def count_files_in_subdirs(directory):
    """Summarize how many files live in the 'real' and 'fake' subfolders.

    A missing or non-directory subfolder simply counts as zero rather
    than raising.

    Args:
        directory: Path whose 'real' and 'fake' children are inspected.

    Returns:
        str: A summary of the form ``"Real: <n>, Fake: <m>"``.
    """
    def _count(label):
        # Count only regular files directly inside the subfolder.
        sub = os.path.join(directory, label)
        if not (os.path.exists(sub) and os.path.isdir(sub)):
            return 0
        return sum(
            1
            for entry in os.listdir(sub)
            if os.path.isfile(os.path.join(sub, entry))
        )

    return f"Real: {_count('real')}, Fake: {_count('fake')}"
def get_dataset_counts():
    """Gets the file counts for train, validation, and test sets.

    Returns:
        tuple[str, str, str]: Summary strings for the 'train', 'valid'
        and 'test' directories (relative to the current working directory),
        in that order.
    """
    return tuple(
        count_files_in_subdirs(split) for split in ('train', 'valid', 'test')
    )
def train_model_gradio(model_variant, ed_pretrained, vae_pretrained, epochs, batch_size, run_test, use_fp16, progress=gr.Progress()):
    """Refactored training function for Gradio UI.

    Generator: yields human-readable status strings that Gradio streams to
    the UI, and drives the `progress` bar as a side effect.

    Args:
        model_variant: One of "AE", "VAE", or "AE & VAE" — selects which
            model(s) are trained.
        ed_pretrained: Filename of an ED checkpoint in "weight/" to resume
            from, or the sentinel "No weights found" / falsy to skip.
        vae_pretrained: Same as above for the VAE checkpoint.
        epochs: Number of training epochs (coerced with int()).
        batch_size: Dataloader batch size (coerced with int()).
        run_test: If truthy, a post-training test run is intended
            (currently a no-op — see the comment near the end).
        use_fp16: Currently unread in this function — presumably intended
            for mixed-precision training; TODO confirm and wire up.
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        str: Status/progress messages for the UI log.
    """
    dir_path = './'
    # Both dataset splits must exist before we attempt to load anything.
    if not (os.path.exists('train') and os.path.exists('valid')):
        yield "Error: 'train' and 'valid' directories not found. Please create them and populate them with 'real' and 'fake' subdirectories."
        return
    yield "Loading data..."
    progress(0, desc="Loading data...")
    try:
        dataloaders, dataset_sizes = load_data(dir_path, int(batch_size))
        yield "Data loaded."
    except Exception as e:
        yield f"Error loading data: {e}. Please ensure the dataset is structured correctly."
        return
    # Collect (tag, model, optimizer) triples; both variants may train in
    # one call when "AE & VAE" is selected.
    models = []
    optimizers = []
    if model_variant in ["AE", "AE & VAE"]:
        yield "Initializing AE model..."
        model_ed = GenConViTED(config)
        optimizer_ed = optim.Adam(model_ed.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if ed_pretrained and ed_pretrained != "No weights found":
            try:
                # Resume weights/optimizer state; epoch and loss are discarded.
                model_ed, optimizer_ed, _, _ = load_checkpoint(model_ed, optimizer_ed, filename=os.path.join("weight", ed_pretrained))
            except Exception as e:
                # NOTE(review): training continues from scratch after a
                # failed checkpoint load — the error is only reported.
                yield f"Error loading ED checkpoint: {e}"
        models.append(("ed", model_ed, optimizer_ed))
    if model_variant in ["VAE", "AE & VAE"]:
        yield "Initializing VAE model..."
        model_vae = GenConViTVAE(config)
        optimizer_vae = optim.Adam(model_vae.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if vae_pretrained and vae_pretrained != "No weights found":
            try:
                model_vae, optimizer_vae, _, _ = load_checkpoint(model_vae, optimizer_vae, filename=os.path.join("weight", vae_pretrained))
            except Exception as e:
                yield f"Error loading VAE checkpoint: {e}"
        models.append(("vae", model_vae, optimizer_vae))
    # Train each selected model sequentially.
    for mod, model, optimizer in models:
        yield f"Starting training for {mod.upper()} model..."
        criterion = nn.CrossEntropyLoss().to(device)
        mse = nn.MSELoss()
        scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
        model.to(device)
        # Per-epoch history lists, appended to by the train/valid helpers.
        train_loss, train_acc, valid_loss, valid_acc = [], [], [], []
        train_func, valid_func = None, None
        # Deferred imports: each variant has its own train/valid loop.
        if mod == 'ed':
            from train.train_ed import train as train_func_ed, valid as valid_func_ed
            train_func, valid_func = train_func_ed, valid_func_ed
        else:
            from train.train_vae import train as train_func_vae, valid as valid_func_vae
            train_func, valid_func = train_func_vae, valid_func_vae
        for epoch in range(int(epochs)):
            epoch_desc = f"Epoch {epoch+1}/{int(epochs)} ({mod.upper()})"
            progress(epoch / int(epochs), desc=epoch_desc)
            yield f"{epoch_desc} - Training..."
            # NOTE(review): epoch_acc is assigned but never used below.
            epoch_loss, epoch_acc = 0,0
            try:
                train_loss, train_acc, epoch_loss = train_func(model, device, dataloaders["train"], criterion, optimizer, epoch, train_loss, train_acc, mse)
            except Exception as e:
                # Abort this model's epoch loop; the partial model is
                # still saved below.
                yield f"Error during training: {e}"
                break
            yield f"{epoch_desc} - Validation..."
            try:
                valid_loss, valid_acc = valid_func(model, device, dataloaders["validation"], criterion, epoch, valid_loss, valid_acc, mse)
                yield f"Epoch {epoch+1} complete for {mod.upper()}. Validation Loss: {valid_loss[-1]:.4f}, Validation Acc: {valid_acc[-1]:.4f}"
            except Exception as e:
                yield f"Error during validation: {e}"
                break
            scheduler.step()
        yield f"Training complete for {mod.upper()}. Saving model..."
        progress(1, desc=f"Saving {mod.upper()} model...")
        # Timestamped base path; .pkl holds history, .pth holds weights.
        file_path = os.path.join("weight", f'genconvit_{mod}_{time.strftime("%b_%d_%Y_%H_%M_%S", time.localtime())}')
        with open(f"{file_path}.pkl", "wb") as f:
            pickle.dump([train_loss, train_acc, valid_loss, valid_acc], f)
        state = {
            "epoch": epochs, "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(), "min_loss": epoch_loss,
        }
        weight_filename = f"{file_path}.pth"
        torch.save(state, weight_filename)
        yield f"Model saved to {weight_filename}"
        if run_test:
            yield f"Running test for {mod.upper()} model..."
            # test() function from train.py needs to be refactored to be callable here
            pass
    yield "All training processes finished."
def predict_video_gradio(video_path, ed_weight, vae_weight, num_frames, use_fp16, progress=gr.Progress()):
    """Refactored prediction function for Gradio UI.

    Generator: yields (status, label, confidence) 3-tuples that Gradio
    streams into three output components.

    BUG FIX: the original mixed ``yield`` with ``return <values>``. A
    function containing ``yield`` is a generator, so ``return x`` only sets
    ``StopIteration.value`` — Gradio never received any of the early-exit
    or final results. Every value-return is now ``yield …`` followed by a
    bare ``return``. The no-video path also returned a 4-tuple where every
    other path produced 3 values; it now yields a consistent 3-tuple.

    Args:
        video_path: Path of the uploaded video (falsy if none uploaded).
        ed_weight: ED checkpoint filename in "weight/", or the sentinel
            "No weights found" / falsy to skip.
        vae_weight: Same as above for the VAE checkpoint.
        num_frames: Number of frames to sample (coerced with int()).
        use_fp16: Forwarded to load_genconvit (half-precision flag).
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        tuple[str, str, str]: (status message, REAL/FAKE label, confidence).
    """
    if not video_path:
        yield "Please upload a video.", "", ""
        return
    # Resolve which network(s) to load from the selected weight files.
    net_type = None
    ed_weight_path, vae_weight_path = None, None
    if ed_weight and ed_weight != "No weights found":
        ed_weight_path = os.path.join("weight", ed_weight)
    if vae_weight and vae_weight != "No weights found":
        vae_weight_path = os.path.join("weight", vae_weight)
    if ed_weight_path and vae_weight_path:
        net_type = 'genconvit'  # both halves -> full ensemble
    elif ed_weight_path:
        net_type = 'ed'
    elif vae_weight_path:
        net_type = 'vae'
    else:
        yield "Status: Error", "Please select at least one model weight.", ""
        return
    yield "Status: Loading model...", "", ""
    progress(0.1, desc="Loading model...")
    try:
        model = load_genconvit(config, net_type, ed_weight_path, vae_weight_path, use_fp16)
    except Exception as e:
        yield f"Status: Error loading model", f"Details: {e}", ""
        return
    yield "Status: Model loaded. Extracting faces...", "", ""
    progress(0.3, desc="Extracting faces...")
    try:
        faces = df_face(video_path, int(num_frames))
        if len(faces) == 0:
            yield "Status: Error", "No faces detected in the video.", ""
            return
    except Exception as e:
        yield "Status: Error during face extraction", f"Details: {e}. Is dlib installed correctly?", ""
        return
    yield f"Status: {len(faces)} face(s) detected. Running prediction...", "", ""
    progress(0.8, desc="Running prediction...")
    try:
        y, y_val = pred_vid(faces, model)
        label = real_or_fake(y)
        # y_val is the model score for REAL; flip it so the displayed
        # confidence always refers to the predicted label.
        score = y_val if label == "REAL" else 1 - y_val
        confidence_str = f"{score*100:.2f}%"
        progress(1, desc="Prediction complete")
        yield f"Status: Prediction complete.", label, confidence_str
    except Exception as e:
        yield "Status: Error during prediction", f"Details: {e}", ""