File size: 5,693 Bytes
68f9b9e
 
9b06d51
68f9b9e
 
 
 
9b06d51
 
 
 
 
b197185
 
 
 
 
 
 
 
9b06d51
 
 
b197185
68f9b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b06d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f9b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import sys
from types import ModuleType

# Set Keras backend to torch for Python 3.14 compatibility.
# This must run before the `import keras` below: Keras picks its backend from
# KERAS_BACKEND when it is first imported. Note that it unconditionally
# overrides any value already present in the environment.
os.environ["KERAS_BACKEND"] = "torch"

# --- Monkey-patch for 'fer' library (to avoid TensorFlow requirement) ---
# Installs a stub `tensorflow` module in sys.modules and aliases
# `tensorflow.keras`, `tensorflow.keras.models` and `tensorflow.keras.layers`
# to standalone Keras, so later `import tensorflow...` statements (presumably
# the ones inside `fer`, imported further down -- verify) resolve without a
# real TensorFlow install.
# NOTE(review): this replaces any real `tensorflow` entry in sys.modules for
# the rest of the process.
try:
    import keras
    # Empty module object standing in for the `tensorflow` package.
    tf = ModuleType("tensorflow")
    sys.modules["tensorflow"] = tf
    sys.modules["tensorflow.keras"] = keras
    tf.keras = keras
    # Make sure the submodules are loaded before they are aliased below.
    import keras.models
    import keras.layers
    sys.modules["tensorflow.keras.models"] = keras.models
    sys.modules["tensorflow.keras.layers"] = keras.layers
    # tf.keras *is* `keras`, so these two lines rebind keras.models and
    # keras.layers to the same objects they already refer to.
    tf.keras.models = keras.models
    tf.keras.layers = keras.layers
except Exception as e:
    # Best-effort: the failure is only printed and the script carries on.
    # NOTE(review): if this fails after sys.modules["tensorflow"] was set, the
    # stub stays registered without the models/layers aliases.
    print(f"DEBUG: Monkey-patch failed: {e}")


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from PIL import Image

# Make the sibling "backend" directory importable so the `app` package
# (app.database, app.models) resolves. It is prepended to sys.path, so it takes
# precedence over any other `app` package found later on the path.
base_dir = os.path.split(os.path.abspath(__file__))[0]
backend_dir = os.path.join(base_dir, "backend")
sys.path[:0] = [backend_dir]

from app.database import SessionLocal
from app.models import EmotionLog
from fer import FER

# 1. Simple Neural Network to "learn" new emotions
class EmotionClassifier(nn.Module):
    """Two-layer MLP mapping a feature vector to custom emotion logits.

    Architecture: Linear(input_size -> hidden_size) -> ReLU ->
    Linear(hidden_size -> num_classes). The output is raw, unnormalised
    logits (no softmax is applied here).

    Args:
        input_size: length of each input feature vector.
        num_classes: number of output classes (custom emotion labels).
        hidden_size: width of the hidden layer (default 64).
    """

    def __init__(self, input_size, num_classes, hidden_size=64):
        super().__init__()
        # Attribute names define the state_dict keys ("fc1.*", "fc2.*");
        # keep them stable so previously saved weights still load.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (..., num_classes) for input (..., input_size)."""
        return self.fc2(self.relu(self.fc1(x)))

def _load_corrections():
    """Return ``(image_path, corrected_emotion)`` for every corrected EmotionLog.

    Only plain column values are returned, so the session can be closed here;
    it is closed even if the query raises.
    """
    db = SessionLocal()
    try:
        logs = (
            db.query(EmotionLog)
            # SQLAlchemy renders `!= None` as `IS NOT NULL`.
            .filter(EmotionLog.corrected_emotion != None)  # noqa: E711
            .all()
        )
        return [(log.image_path, log.corrected_emotion) for log in logs]
    finally:
        db.close()


def _extract_features(detector, img_path):
    """Return FER emotion scores for the first detected face, or None to skip.

    Scores are ordered by sorted emotion name so the feature order is
    deterministic and can match inference. Prints why an image is skipped
    (missing file, unreadable image, or no face found).
    """
    if not os.path.exists(img_path):
        print(f"Skipping missing image: {img_path}")
        return None
    img = cv2.imread(img_path)
    if img is None:
        print(f"Skipping unreadable image: {img_path}")
        return None
    results = detector.detect_emotions(img)
    if not results:
        print(f"Warning: FER could not find face in {img_path}, skipping.")
        return None
    emo_dict = results[0]["emotions"]
    return [emo_dict[k] for k in sorted(emo_dict)]


def _fit(model, X, y, epochs, lr):
    """Full-batch train *model* on (X, y) with Adam + cross-entropy."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        # Call the module (not .forward) so nn.Module hooks are honoured.
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        if epoch % 50 == 0:
            print(f"Epoch [{epoch}/{epochs}], Loss: {loss.item():.4f}")


def _save_model(model, idx_to_emotion):
    """Write weights and the index->label map to backend/app/models/custom_ai."""
    import json

    save_path = os.path.join(backend_dir, "app", "models", "custom_ai")
    os.makedirs(save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_path, "custom_weights.pth"))
    # JSON object keys are always strings: indices are stored as "0", "1", ...
    with open(os.path.join(save_path, "labels.json"), "w", encoding="utf-8") as f:
        json.dump(idx_to_emotion, f)


def train(*, epochs=200, lr=0.01):
    """Train the custom emotion classifier from user-corrected EmotionLog rows.

    Steps: load every log that has a corrected_emotion; run FER on the stored
    image (under backend/data) to get its emotion scores as features; fit an
    EmotionClassifier mapping those scores to the corrected labels; save the
    weights and the index->label map under backend/app/models/custom_ai.

    Args:
        epochs: number of full-batch optimisation steps.
        lr: Adam learning rate.

    Exits the process with status 1 when there are no corrections, fewer than
    two distinct corrected emotions, fewer than 3 corrections, or when no
    features could be extracted from any image.
    """
    from collections import Counter

    samples = _load_corrections()

    if not samples:
        print("FAILED: No corrected emotions found. Correct at least 3 images with at least two different emotions to train the custom brain.")
        sys.exit(1)

    all_emotions = sorted({emotion for _, emotion in samples})

    if len(all_emotions) < 2:
        print(f"FAILED: Not enough variety! You only have one corrected emotion: {all_emotions}. Correct at least two different images with different emotions to train the custom brain.")
        sys.exit(1)

    if len(samples) < 3:
        print(f"FAILED: Not enough data! You only have {len(samples)} corrections. Please correct at least 3 images (from different photos) to help the AI learn.")
        sys.exit(1)

    # Prepare labels and mappings
    emotion_to_idx = {emo: i for i, emo in enumerate(all_emotions)}
    idx_to_emotion = dict(enumerate(all_emotions))

    # Load the detector only once we know there is enough data to train.
    detector = FER(mtcnn=False)

    features, labels = [], []
    print(f"Preparing data for emotions: {all_emotions}")

    for image_path, emotion in samples:
        img_path = os.path.join(backend_dir, "data", image_path)
        try:
            row = _extract_features(detector, img_path)
        except Exception as e:
            # Best-effort: one bad image must not abort the whole run.
            print(f"Error processing {img_path}: {e}")
            continue
        if row is None:
            continue
        features.append(row)
        labels.append(emotion_to_idx[emotion])

    if not features:
        print("FAILED: Could not extract features from any of your corrected images.")
        sys.exit(1)

    # Print summary of samples per emotion
    print("\nTraining Data Summary:")
    counts = Counter(idx_to_emotion[idx] for idx in labels)
    for emo, count in sorted(counts.items()):
        print(f" - {emo}: {count} samples")

    # Labels whose images all failed keep an output unit but are never trained.
    missing = set(all_emotions) - set(counts)
    if missing:
        print(f"\nWARNING: The following emotions have 0 valid samples (face detection failed): {list(missing)}")
        print("These emotions will NOT be recognized by the custom brain. Try using clearer photos.")

    print(f"\nTotal training samples: {len(labels)}\n")

    X = torch.tensor(features, dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.long)

    # 2. Train the model. Input width is taken from the extracted features
    # (FER's 7 base emotions); NOTE(review): inference must build the model
    # with the same input_size to load these weights.
    model = EmotionClassifier(input_size=X.shape[1], num_classes=len(all_emotions))
    print("Training the custom brain...")
    _fit(model, X, y, epochs, lr)

    # 3. Save the custom model and the label mapping
    _save_model(model, idx_to_emotion)

    print(f"SUCCESS: Custom brain trained and saved. It now knows: {all_emotions}")


if __name__ == "__main__":
    train()