File size: 8,683 Bytes
e0c75d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# app_utils.py
# This file will contain the refactored core logic for training and prediction.
import os
import time
import pickle
import torch
import gradio as gr
from torch import nn
from torch import optim
from torch.optim import lr_scheduler

from model.config import load_config
from model.genconvit_ed import GenConViTED
from model.genconvit_vae import GenConViTVAE
from dataset.loader import load_data, load_checkpoint
from model.pred_func import set_result, load_genconvit, df_face, pred_vid, real_or_fake

# Load configuration
config = load_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_available_weights(weight_dir="weight"):
    """Return the names of all ``.pth`` checkpoint files in *weight_dir*.

    The directory is created if it does not exist yet. When no checkpoint
    is present, the placeholder list ``["No weights found"]`` is returned
    so UI dropdowns always have at least one entry.
    """
    try:
        entries = os.listdir(weight_dir)
    except FileNotFoundError:
        # First run: create the directory; it is necessarily empty.
        os.makedirs(weight_dir)
        entries = []
    checkpoints = [name for name in entries if name.endswith(".pth")]
    return checkpoints or ["No weights found"]

def _count_regular_files(path):
    """Return the number of regular files directly inside *path* (0 if the directory is missing)."""
    # os.path.isdir already returns False for nonexistent paths, so the
    # original's separate os.path.exists check was redundant.
    if not os.path.isdir(path):
        return 0
    return sum(1 for name in os.listdir(path) if os.path.isfile(os.path.join(path, name)))

def count_files_in_subdirs(directory):
    """Counts files in the 'real' and 'fake' subdirectories of a given directory.

    Args:
        directory: Dataset split directory (e.g. 'train') expected to contain
            'real' and 'fake' subfolders. Missing subfolders count as 0.

    Returns:
        A display string of the form "Real: <n>, Fake: <m>".
    """
    real_count = _count_regular_files(os.path.join(directory, 'real'))
    fake_count = _count_regular_files(os.path.join(directory, 'fake'))
    return f"Real: {real_count}, Fake: {fake_count}"

def get_dataset_counts():
    """Return (train, valid, test) file-count summary strings.

    Each element is the "Real: n, Fake: m" string produced by
    count_files_in_subdirs for the corresponding split directory,
    resolved relative to the current working directory.
    """
    splits = ('train', 'valid', 'test')
    return tuple(count_files_in_subdirs(split) for split in splits)

def train_model_gradio(model_variant, ed_pretrained, vae_pretrained, epochs, batch_size, run_test, use_fp16, progress=gr.Progress()):
    """Refactored training function for Gradio UI.

    Generator that trains the selected GenConViT sub-model(s) and yields
    human-readable status strings which Gradio streams into the UI.

    Args:
        model_variant: "AE", "VAE", or "AE & VAE" — which sub-model(s) to train.
        ed_pretrained: checkpoint filename inside ./weight used to warm-start
            the ED model, or the "No weights found" placeholder / falsy value
            to train from scratch.
        vae_pretrained: same, for the VAE model.
        epochs: epoch count (may arrive as a string/float from the UI; cast
            with int() where used).
        batch_size: dataloader batch size (cast with int() below).
        run_test: if truthy, a post-training test pass is requested —
            currently a no-op (see comment near the end).
        use_fp16: NOTE(review): accepted but never used in this function.
        progress: Gradio progress tracker driving the progress bar.

    Yields:
        Status/progress strings; error paths yield a message and return early.
    """
    # Dataset root; load_data() is expected to find train/ and valid/ under it.
    dir_path = './'

    if not (os.path.exists('train') and os.path.exists('valid')):
        yield "Error: 'train' and 'valid' directories not found. Please create them and populate them with 'real' and 'fake' subdirectories."
        return

    yield "Loading data..."
    progress(0, desc="Loading data...")
    try:
        dataloaders, dataset_sizes = load_data(dir_path, int(batch_size))
        yield "Data loaded."
    except Exception as e:
        yield f"Error loading data: {e}. Please ensure the dataset is structured correctly."
        return

    # Collected as ("ed"|"vae", model, optimizer) triples; both variants are
    # trained sequentially by the loop below.
    models = []
    optimizers = []  # NOTE(review): never populated — optimizers travel inside `models`.

    if model_variant in ["AE", "AE & VAE"]:
        yield "Initializing AE model..."
        model_ed = GenConViTED(config)
        optimizer_ed = optim.Adam(model_ed.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if ed_pretrained and ed_pretrained != "No weights found":
            try:
                model_ed, optimizer_ed, _, _ = load_checkpoint(model_ed, optimizer_ed, filename=os.path.join("weight", ed_pretrained))
            except Exception as e:
                # Non-fatal: report the failure and continue from scratch.
                yield f"Error loading ED checkpoint: {e}"
        models.append(("ed", model_ed, optimizer_ed))

    if model_variant in ["VAE", "AE & VAE"]:
        yield "Initializing VAE model..."
        model_vae = GenConViTVAE(config)
        optimizer_vae = optim.Adam(model_vae.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if vae_pretrained and vae_pretrained != "No weights found":
            try:
                model_vae, optimizer_vae, _, _ = load_checkpoint(model_vae, optimizer_vae, filename=os.path.join("weight", vae_pretrained))
            except Exception as e:
                 yield f"Error loading VAE checkpoint: {e}"
        models.append(("vae", model_vae, optimizer_vae))

    for mod, model, optimizer in models:
        yield f"Starting training for {mod.upper()} model..."

        criterion = nn.CrossEntropyLoss().to(device)
        mse = nn.MSELoss()  # reconstruction loss, forwarded to the train/valid helpers
        # Decay the learning rate 10x every 15 epochs.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
        model.to(device)

        # Per-epoch history curves; passed into (and returned from) the
        # train/valid helpers, which presumably append to them — the curves
        # are pickled after training below.
        train_loss, train_acc, valid_loss, valid_acc = [], [], [], []

        # Deferred imports select the variant-specific train/valid helpers.
        train_func, valid_func = None, None
        if mod == 'ed':
            from train.train_ed import train as train_func_ed, valid as valid_func_ed
            train_func, valid_func = train_func_ed, valid_func_ed
        else:
            from train.train_vae import train as train_func_vae, valid as valid_func_vae
            train_func, valid_func = train_func_vae, valid_func_vae

        for epoch in range(int(epochs)):
            epoch_desc = f"Epoch {epoch+1}/{int(epochs)} ({mod.upper()})"
            progress(epoch / int(epochs), desc=epoch_desc)

            yield f"{epoch_desc} - Training..."
            # NOTE(review): epoch_acc is initialized here but never written afterwards.
            epoch_loss, epoch_acc = 0,0
            try:
                 train_loss, train_acc, epoch_loss = train_func(model, device, dataloaders["train"], criterion, optimizer, epoch, train_loss, train_acc, mse)
            except Exception as e:
                # Abort this model's epoch loop; saving below still runs.
                yield f"Error during training: {e}"
                break

            yield f"{epoch_desc} - Validation..."
            try:
                valid_loss, valid_acc = valid_func(model, device, dataloaders["validation"], criterion, epoch, valid_loss, valid_acc, mse)
                yield f"Epoch {epoch+1} complete for {mod.upper()}. Validation Loss: {valid_loss[-1]:.4f}, Validation Acc: {valid_acc[-1]:.4f}"
            except Exception as e:
                yield f"Error during validation: {e}"
                break

            scheduler.step()

        yield f"Training complete for {mod.upper()}. Saving model..."
        progress(1, desc=f"Saving {mod.upper()} model...")

        # Timestamped base path under ./weight; history curves go to .pkl,
        # the checkpoint itself to .pth.
        file_path = os.path.join("weight", f'genconvit_{mod}_{time.strftime("%b_%d_%Y_%H_%M_%S", time.localtime())}')
        with open(f"{file_path}.pkl", "wb") as f:
            pickle.dump([train_loss, train_acc, valid_loss, valid_acc], f)

        # NOTE(review): "epoch" stores the requested total rather than the last
        # epoch actually reached, and "min_loss" is the *last* epoch's training
        # loss (0 if training failed on epoch 1) — confirm this matches what
        # load_checkpoint expects.
        state = {
            "epoch": epochs, "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(), "min_loss": epoch_loss,
        }
        weight_filename = f"{file_path}.pth"
        torch.save(state, weight_filename)
        yield f"Model saved to {weight_filename}"

        if run_test:
            yield f"Running test for {mod.upper()} model..."
            # test() function from train.py needs to be refactored to be callable here
            pass

    yield "All training processes finished."


def predict_video_gradio(video_path, ed_weight, vae_weight, num_frames, use_fp16, progress=gr.Progress()):
    """Refactored prediction function for Gradio UI.

    Generator yielding (status, label, confidence) 3-tuples that Gradio
    streams into the three output components.

    Args:
        video_path: path of the uploaded video (falsy when nothing uploaded).
        ed_weight / vae_weight: checkpoint filenames inside ./weight, or the
            "No weights found" placeholder produced by get_available_weights.
        num_frames: number of frames to sample for face extraction (cast to int).
        use_fp16: forwarded to load_genconvit for half-precision inference.
        progress: Gradio progress tracker.

    Bug fixes vs. the previous version:
    - This function is a generator (it uses ``yield``), so every early
      ``return value, ...`` silently discarded its value: a plain ``return``
      in a generator only sets the StopIteration payload, which Gradio never
      shows. All terminal results are now yielded before returning.
    - The "no video" exit produced a 4-tuple while every other path produced
      a 3-tuple; all paths now yield exactly 3 values.
    """
    if not video_path:
        yield "Please upload a video.", "", ""
        return

    net_type = None
    ed_weight_path, vae_weight_path = None, None

    # "No weights found" is the placeholder entry from get_available_weights().
    if ed_weight and ed_weight != "No weights found":
        ed_weight_path = os.path.join("weight", ed_weight)
    if vae_weight and vae_weight != "No weights found":
        vae_weight_path = os.path.join("weight", vae_weight)

    # Both weights selected -> combined network; otherwise the single variant.
    if ed_weight_path and vae_weight_path:
        net_type = 'genconvit'
    elif ed_weight_path:
        net_type = 'ed'
    elif vae_weight_path:
        net_type = 'vae'
    else:
        yield "Status: Error", "Please select at least one model weight.", ""
        return

    yield "Status: Loading model...", "", ""
    progress(0.1, desc="Loading model...")
    try:
        model = load_genconvit(config, net_type, ed_weight_path, vae_weight_path, use_fp16)
    except Exception as e:
        yield "Status: Error loading model", f"Details: {e}", ""
        return

    yield "Status: Model loaded. Extracting faces...", "", ""
    progress(0.3, desc="Extracting faces...")

    try:
        faces = df_face(video_path, int(num_frames))
        if len(faces) == 0:
            yield "Status: Error", "No faces detected in the video.", ""
            return
    except Exception as e:
        yield "Status: Error during face extraction", f"Details: {e}. Is dlib installed correctly?", ""
        return

    yield f"Status: {len(faces)} face(s) detected. Running prediction...", "", ""
    progress(0.8, desc="Running prediction...")

    try:
        y, y_val = pred_vid(faces, model)
        label = real_or_fake(y)
        # y_val scores the predicted class; flip it so `score` is always the
        # confidence in the reported label.
        score = y_val if label == "REAL" else 1 - y_val
        confidence_str = f"{score*100:.2f}%"

        progress(1, desc="Prediction complete")
        yield "Status: Prediction complete.", label, confidence_str
    except Exception as e:
        yield "Status: Error during prediction", f"Details: {e}", ""