# Sonofica/utils/train_speaker.py — commit 39a7537 ("Add application file") by janmayjay
#!/usr/bin/env python3
"""
GCN training pipeline for associating speech bubbles with speaker faces.

Loads panel annotations from JSON, builds a bipartite bubble-face graph per
panel, trains an edge classifier, and decodes associations with Hungarian
matching.
"""
import json
import torch
import numpy as np
import random
from torch_geometric.data import HeteroData, Batch
import torch.nn as nn
from scipy.optimize import linear_sum_assignment
from typing import Dict, List, Any, Optional, Tuple
# from utils.utilities import save_checkpoint
import os
from pathlib import Path
# Default directory for model checkpoints (relative to the current working
# directory); created when this module is imported.
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)


def save_checkpoint(model: torch.nn.Module,
                    epoch: int,
                    loss: float,
                    path: Path = CHECKPOINT_DIR / "assoc_gcn.pt",
                    optimizer: Optional[torch.optim.Optimizer] = None) -> None:
    """
    Save model weights plus epoch/loss metadata with ``torch.save``.

    The saved object is a dict with keys ``"epoch"``, ``"loss"`` and
    ``"model_state"`` (the model's ``state_dict()``). When ``optimizer`` is
    given, its ``state_dict()`` is stored under ``"optimizer_state"`` too,
    which is needed to resume training with the same optimizer state.

    Args:
        model: Module whose parameters are saved.
        epoch: Epoch number recorded alongside the weights.
        loss: Loss value recorded alongside the weights.
        path: Destination file. Missing parent directories are created.
        optimizer: Optional optimizer whose state is saved as well.
    """
    path = Path(path)
    # A caller-supplied path may live outside CHECKPOINT_DIR; without this,
    # torch.save fails when the parent directory does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    checkpoint = {
        "epoch": epoch,
        "loss": loss,
        "model_state": model.state_dict(),
    }
    if optimizer is not None:
        checkpoint["optimizer_state"] = optimizer.state_dict()
    torch.save(checkpoint, path)
    print(f"✅ Model checkpoint saved to {path.resolve()}")
class DatasetLoader:
    """Load panel annotations from JSON and turn them into PyG ``HeteroData`` graphs.

    Each panel becomes a bipartite graph with ``bubble`` and ``face`` node
    types and a fully connected ``('bubble', 'to', 'face')`` edge set whose
    ``edge_label`` marks the annotated bubble -> face links.
    """

    @staticmethod
    def load_converted_dataset(json_path: str) -> List[HeteroData]:
        """Load every usable panel from a converted-dataset JSON file.

        Args:
            json_path: Path to a JSON file with a top-level ``"panels"`` list.

        Returns:
            One ``HeteroData`` per panel. Panels for which
            ``create_hetero_data_from_panel`` returns ``None`` are skipped.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        dataset = []
        for panel in data.get('panels', []):
            het_data = DatasetLoader.create_hetero_data_from_panel(panel)
            if het_data is not None:
                dataset.append(het_data)
        print(f"Loaded {len(dataset)} panels from {json_path}")
        return dataset

    @staticmethod
    def _box_features(bbox, W, H) -> List[float]:
        """Return [cx, cy, w, h, area, aspect] for an (x1, y1, x2, y2) box, normalized by panel size."""
        x1, y1, x2, y2 = bbox
        cx, cy = (x1 + x2) / (2 * W), (y1 + y2) / (2 * H)
        w, h = (x2 - x1) / W, (y2 - y1) / H
        # Degenerate (zero-height) boxes get a neutral aspect ratio.
        aspect = w / h if h > 0 else 1.0
        return [cx, cy, w, h, w * h, aspect]

    @staticmethod
    def _iou(box_a, box_b) -> float:
        """Intersection-over-union of two (x1, y1, x2, y2) boxes; 0 when the union is empty."""
        ax1, ay1, ax2, ay2 = box_a
        bx1, by1, bx2, by2 = box_b
        inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
        inter_h = max(0, min(ay2, by2) - max(ay1, by1))
        inter = inter_w * inter_h
        union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
        return inter / union if union > 0 else 0

    @staticmethod
    def create_hetero_data_from_panel(panel: Dict) -> Optional[HeteroData]:
        """Convert one panel dict into a bipartite bubble/face ``HeteroData`` graph.

        Keys read from ``panel``: ``bubbles`` / ``faces`` (lists of dicts with
        ``bbox`` = [x1, y1, x2, y2] and ``bubble_id`` / ``face_id``), optional
        ``links`` (dicts with ``bubble_id`` and ``face_id``), ``width``,
        ``height`` and ``panel_id``.

        Node features come from ``_box_features`` (6 values per node). Every
        bubble is connected to every face. Edge features are
        ``[dx, dy, dist, iou]``: dx/dy/dist use panel-normalized centers and
        IoU uses the raw boxes. ``edge_label`` is 1.0 for annotated links and
        0.0 otherwise.

        Returns:
            The graph, or ``None`` when the panel has no bubbles, has no faces,
            or has a non-positive width/height (normalization would divide by
            zero).
        """
        bubbles = panel.get('bubbles', [])
        faces = panel.get('faces', [])
        links = panel.get('links', [])
        if not bubbles or not faces:
            return None
        W, H = panel['width'], panel['height']
        if W <= 0 or H <= 0:
            return None

        bubble_features = [DatasetLoader._box_features(b['bbox'], W, H) for b in bubbles]
        face_features = [DatasetLoader._box_features(f['bbox'], W, H) for f in faces]

        bubble_id_to_idx = {b['bubble_id']: i for i, b in enumerate(bubbles)}
        face_id_to_idx = {f['face_id']: i for i, f in enumerate(faces)}
        # Ground-truth (bubble_idx, face_idx) pairs; links that reference
        # unknown bubble/face ids are ignored.
        gt_pairs = {
            (bubble_id_to_idx[link['bubble_id']], face_id_to_idx[link['face_id']])
            for link in links
            if link['bubble_id'] in bubble_id_to_idx and link['face_id'] in face_id_to_idx
        }

        edge_indices, edge_features, edge_labels = [], [], []
        for i, (bubble, b_feat) in enumerate(zip(bubbles, bubble_features)):
            for j, (face, f_feat) in enumerate(zip(faces, face_features)):
                # Indices 0 and 1 of the node features are the normalized centers.
                dx, dy = b_feat[0] - f_feat[0], b_feat[1] - f_feat[1]
                dist = (dx ** 2 + dy ** 2) ** 0.5
                iou = DatasetLoader._iou(bubble['bbox'], face['bbox'])
                edge_indices.append([i, j])
                edge_features.append([dx, dy, dist, iou])
                edge_labels.append(1.0 if (i, j) in gt_pairs else 0.0)

        data = HeteroData()
        data['bubble'].x = torch.tensor(bubble_features, dtype=torch.float)
        data['face'].x = torch.tensor(face_features, dtype=torch.float)
        # edge_index has shape (2, num_edges): row 0 = bubble idx, row 1 = face idx.
        data['bubble', 'to', 'face'].edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        data['bubble', 'to', 'face'].edge_attr = torch.tensor(edge_features, dtype=torch.float)
        data['bubble', 'to', 'face'].edge_label = torch.tensor(edge_labels, dtype=torch.float)
        # Panel metadata carried along with the graph.
        data.panel_id = panel['panel_id']
        data.width = W
        data.height = H
        return data
class AssocGCN(nn.Module):
    """Edge classifier over a bipartite bubble/face graph.

    Bubble and face nodes share one MLP encoder. Each of three
    message-passing rounds does the following:
    - computes one update vector per edge from ``[bubble, face, edge_attr]``;
    - adds the degree-averaged sum of those updates to both endpoint nodes
      (residual).

    A final MLP scores every edge, and ``forward`` returns one raw logit
    per edge.

    Args:
        in_feats: Size of the bubble/face node feature vectors.
        hid: Hidden size.
        edge_feats: Size of ``edge_attr`` per edge (default 4, matching the
            ``[dx, dy, dist, iou]`` features built in this file).
    """

    def __init__(self, in_feats: int = 6, hid: int = 128, edge_feats: int = 4):
        super().__init__()
        self.node_encoder = nn.Sequential(
            nn.Linear(in_feats, hid),
            nn.ReLU(),
            nn.Linear(hid, hid),
        )
        # Kept as three named attributes (created in this order) so parameter
        # init order and state_dict keys (conv1.*, conv2.*, conv3.*) match
        # existing checkpoints.
        self.conv1 = self._message_layer(hid, edge_feats)
        self.conv2 = self._message_layer(hid, edge_feats)
        self.conv3 = self._message_layer(hid, edge_feats)
        # Edge classifier over [bubble, face, edge_attr] -> 1 logit.
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hid + edge_feats, hid),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hid, 1),
        )

    @staticmethod
    def _message_layer(hid: int, edge_feats: int) -> nn.Sequential:
        """Build one per-edge message MLP: [bubble, face, edge_attr] -> hid."""
        return nn.Sequential(
            nn.Linear(hid * 2 + edge_feats, hid),
            nn.ReLU(),
            nn.Linear(hid, hid),
        )

    def forward(self, data):
        """Return a 1-D tensor of edge logits, one per ``('bubble', 'to', 'face')`` edge."""
        bubble_x = self.node_encoder(data['bubble'].x)
        face_x = self.node_encoder(data['face'].x)
        edge_index = data['bubble', 'to', 'face'].edge_index
        edge_attr = data['bubble', 'to', 'face'].edge_attr
        src_idx, dst_idx = edge_index[0], edge_index[1]

        # Degrees do not change between rounds; compute them once.
        # clamp(min=1) keeps isolated nodes from dividing by zero.
        bubble_deg = torch.bincount(src_idx, minlength=bubble_x.size(0)).float().clamp(min=1).unsqueeze(1)
        face_deg = torch.bincount(dst_idx, minlength=face_x.size(0)).float().clamp(min=1).unsqueeze(1)

        for conv in (self.conv1, self.conv2, self.conv3):
            edge_input = torch.cat([bubble_x[src_idx], face_x[dst_idx], edge_attr], dim=1)
            edge_updates = conv(edge_input)
            # Scatter-sum each edge's update into its endpoints in one
            # vectorized op. This replaces a per-edge Python loop with .item()
            # calls, which was slow and forced a device sync per edge on GPU.
            bubble_updates = torch.zeros_like(bubble_x).index_add(0, src_idx, edge_updates)
            face_updates = torch.zeros_like(face_x).index_add(0, dst_idx, edge_updates)
            # Residual connection with mean aggregation.
            bubble_x = bubble_x + bubble_updates / bubble_deg
            face_x = face_x + face_updates / face_deg

        edge_input = torch.cat([bubble_x[src_idx], face_x[dst_idx], edge_attr], dim=1)
        return self.edge_mlp(edge_input).squeeze(-1)
def hungarian_matching(scores: torch.Tensor, src_indices, dst_indices, min_score: float = 0.0):
    """Decode edge logits into a one-to-one bubble -> face assignment.

    Builds a (num_bubbles x num_faces) cost matrix:
    - each scored edge gets ``-sigmoid(logit)``;
    - pairs without an edge get a prohibitive cost of 1e6.

    The matrix is solved with SciPy's Hungarian solver. Each bubble gets at
    most one face and each face at most one bubble, so when bubbles outnumber
    faces some bubbles stay unassigned.

    Args:
        scores: 1-D tensor of raw logits, one per edge.
        src_indices: 1-D tensor of bubble indices aligned with ``scores``.
        dst_indices: 1-D tensor of face indices aligned with ``scores``.
        min_score: Assignments whose probability is not strictly greater than
            this are dropped. The default 0.0 keeps every assignment that lands
            on a real edge (pairs without an edge are always dropped).

    Returns:
        Dict mapping bubble index -> face index (plain Python ints).
    """
    if len(scores) == 0:
        return {}
    src = src_indices.detach().cpu().numpy()
    dst = dst_indices.detach().cpu().numpy()
    num_bubbles = int(src.max()) + 1
    num_faces = int(dst.max()) + 1
    # Missing pairs get a large cost so the solver avoids them when it can;
    # any it is forced to use are filtered out below.
    cost_matrix = np.full((num_bubbles, num_faces), 1e6, dtype=np.float32)
    probs = scores.detach().cpu().sigmoid().numpy()
    cost_matrix[src, dst] = -probs  # negated because the solver minimizes cost
    row_indices, col_indices = linear_sum_assignment(cost_matrix)
    return {
        int(r): int(c)
        for r, c in zip(row_indices, col_indices)
        if -cost_matrix[r, c] > min_score
    }
def train_gcn(dataset: List[HeteroData], epochs: int = 200, batch_size: int = 16, lr: float = 1e-4):
    """Train an ``AssocGCN`` edge classifier on a list of panel graphs.

    Uses AdamW and ``BCEWithLogitsLoss``. ``pos_weight`` is set to the
    negative/positive edge ratio over the whole dataset (9.0 if there are no
    positives). The panels are reshuffled every epoch. After each epoch the
    loss, accuracy, precision, recall and F1 (threshold 0.5 on the sigmoid)
    are printed. A checkpoint is saved whenever the epoch loss improves.

    Args:
        dataset: Graphs with ``('bubble', 'to', 'face')`` ``edge_label``.
        epochs: Number of passes over the dataset.
        batch_size: Panels per mini-batch (must be >= 1).
        lr: AdamW learning rate.

    Returns:
        The trained model, left in training mode.

    Raises:
        ValueError: If ``dataset`` is empty or ``batch_size`` < 1.
    """
    if len(dataset) == 0:
        raise ValueError("Dataset is empty!")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    print(f"Training on {len(dataset)} panels...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AssocGCN().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Every bubble-face pair is a candidate edge, so positives are rare;
    # weight them by the negative/positive ratio.
    total_positive = sum(d['bubble', 'to', 'face'].edge_label.sum().item() for d in dataset)
    total_edges = sum(len(d['bubble', 'to', 'face'].edge_label) for d in dataset)
    pos_weight = (total_edges - total_positive) / total_positive if total_positive > 0 else 9.0
    print(f"Positive edges: {total_positive}/{total_edges} ({100*total_positive/total_edges:.1f}%)")
    print(f"Using pos_weight: {pos_weight:.2f}")
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    model.train()
    best_loss = float("inf")
    for epoch in range(epochs):
        shuffled_dataset = dataset.copy()
        random.shuffle(shuffled_dataset)
        total_loss = 0.0
        total_correct = total_samples = 0
        total_tp = total_fp = total_fn = 0

        for start_idx in range(0, len(shuffled_dataset), batch_size):
            batch_data = shuffled_dataset[start_idx:start_idx + batch_size]
            batch = Batch.from_data_list(batch_data).to(device)

            logits = model(batch)
            labels = batch['bubble', 'to', 'face'].edge_label
            loss = loss_fn(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a per-edge mean; weight it by the number of panels so the
            # epoch value is a panel-weighted average of batch losses.
            total_loss += loss.item() * len(batch_data)

            # Confusion counts at a 0.5 probability threshold (labels are 0/1).
            with torch.no_grad():
                preds = torch.sigmoid(logits) > 0.5
                positives = labels > 0.5
                total_correct += (preds == positives).sum().item()
                total_samples += labels.numel()
                total_tp += (preds & positives).sum().item()
                total_fp += (preds & ~positives).sum().item()
                total_fn += (~preds & positives).sum().item()

        avg_loss = total_loss / len(shuffled_dataset)
        accuracy = total_correct / total_samples if total_samples > 0 else 0.0
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(model, epoch + 1, best_loss)  # store the 1-based epoch
        print(f"Epoch {epoch+1:02d}/{epochs}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}, "
              f"Precision = {precision:.4f}, Recall = {recall:.4f}, F1 = {f1:.4f}")

    print("Training completed!")
    return model
def infer_associations(model, data):
    """Predict bubble -> face assignments for a graph.

    Moves ``data`` to the device of the model's parameters and puts the
    model in eval mode (the mode is not restored afterwards). All edges are
    scored without gradient tracking, and the logits are decoded with
    ``hungarian_matching``.

    Returns:
        Dict mapping bubble index -> face index.
    """
    target_device = next(model.parameters()).device
    graph = data.to(target_device)
    model.eval()
    with torch.no_grad():
        edge_logits = model(graph)
        bubble_idx, face_idx = graph['bubble', 'to', 'face'].edge_index
        assignment = hungarian_matching(edge_logits, bubble_idx, face_idx)
    return assignment
def train_speaker(config):
    """Load every panel file under ``<root>/panel_data`` and train the GCN.

    Each file in the directory (processed in sorted order) is passed to
    ``DatasetLoader.load_converted_dataset``. A file that fails to load is
    reported and skipped, so the remaining files can still be used.

    Args:
        config: Mapping with key ``"root"``, the dataset root directory.

    Returns:
        The trained ``AssocGCN`` model (30 epochs, batch size 16).

    Raises:
        FileNotFoundError: If ``<root>/panel_data`` does not exist.
        ValueError: From ``train_gcn`` if no panels could be loaded.
    """
    # os.path.join works whether or not root ends with a separator; plain
    # string concatenation produced e.g. "datapanel_data/" for root="data".
    panel_dir = os.path.join(config["root"], "panel_data")
    dataset = []
    # sorted() makes the dataset order reproducible across filesystems.
    for panel_data_file in sorted(os.listdir(panel_dir)):
        panel_path = os.path.join(panel_dir, panel_data_file)
        print(panel_data_file)
        try:
            dataset += DatasetLoader.load_converted_dataset(panel_path)
        except FileNotFoundError:
            print(f"Error: {panel_path} not found!")
            print("Please ensure your converted dataset file exists.")
        except Exception as e:
            # Best effort: report the bad file and keep going with the rest.
            print(f"Error loading {panel_path}: {e}")
            print("Please check your dataset format and file paths.")
    return train_gcn(dataset, epochs=30, batch_size=16)