|
|
| """
|
| Complete GCN Training Pipeline for Speech Bubble to Speaker Association
|
| Fixed version that handles the dataset format correctly and resolves training issues.
|
| """
|
|
|
| import json
|
| import torch
|
| import numpy as np
|
| import random
|
| from torch_geometric.data import HeteroData, Batch
|
| import torch.nn as nn
|
| from scipy.optimize import linear_sum_assignment
|
| from typing import Dict, List, Any, Optional, Tuple
|
|
|
| import os
|
|
|
| from pathlib import Path
|
# Output directory for model checkpoints, relative to the current working
# directory. It is created eagerly at import time so that files under it
# (e.g. save_checkpoint's default path) can be written without creating it first.
CHECKPOINT_DIR: Path = Path("checkpoints")
CHECKPOINT_DIR.mkdir(parents=False, exist_ok=True)
|
|
|
def save_checkpoint(model: torch.nn.Module,
                    epoch: int,
                    loss: float,
                    path: Path = CHECKPOINT_DIR / "assoc_gcn.pt",
                    optimizer: Optional[torch.optim.Optimizer] = None) -> None:
    """Serialize model weights (and optionally optimizer state) with ``torch.save``.

    The saved dict always contains ``"epoch"``, ``"loss"`` and ``"model_state"``.
    When ``optimizer`` is given, ``"optimizer_state"`` is stored as well, so
    fine-tuning can be resumed with identical optimizer moments.

    Args:
        model: Module whose ``state_dict()`` is saved.
        epoch: Epoch number recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        path: Destination file. Missing parent directories are created.
        optimizer: Optional optimizer whose ``state_dict()`` is also saved.
    """
    path = Path(path)
    # A custom path may point into a directory that does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)

    state = {
        "epoch": epoch,
        "loss": loss,
        "model_state": model.state_dict(),
    }
    if optimizer is not None:
        state["optimizer_state"] = optimizer.state_dict()

    torch.save(state, path)
    print(f"✅ Model checkpoint saved to {path.resolve()}")
|
|
|
class DatasetLoader:
    """Build PyTorch Geometric ``HeteroData`` graphs from the converted panel JSON.

    Every usable panel becomes one graph with ``'bubble'`` and ``'face'`` node
    types. It also has a complete bipartite ``('bubble', 'to', 'face')`` edge set,
    one edge per bubble/face pair, whose ``edge_label`` marks the annotated
    speaker links.
    """

    @staticmethod
    def load_converted_dataset(json_path: str) -> List[HeteroData]:
        """Read ``json_path`` and convert every usable panel into ``HeteroData``.

        Args:
            json_path: Path to a UTF-8 JSON file whose top-level object has a
                ``"panels"`` list (a missing key yields an empty dataset).

        Returns:
            One graph per panel. Panels for which
            ``create_hetero_data_from_panel`` returns ``None`` are skipped.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        dataset = []
        for panel in data.get('panels', []):
            het_data = DatasetLoader.create_hetero_data_from_panel(panel)
            if het_data is not None:
                dataset.append(het_data)

        print(f"Loaded {len(dataset)} panels from {json_path}")
        return dataset

    @staticmethod
    def _box_features(bbox, W, H) -> List[float]:
        """Return ``[cx, cy, w, h, area, aspect]`` for ``bbox`` normalized by panel size."""
        x1, y1, x2, y2 = bbox
        cx, cy = (x1 + x2) / (2 * W), (y1 + y2) / (2 * H)
        w, h = (x2 - x1) / W, (y2 - y1) / H
        # Zero-height boxes get a neutral aspect ratio instead of dividing by zero.
        aspect = w / h if h > 0 else 1.0
        return [cx, cy, w, h, w * h, aspect]

    @staticmethod
    def _iou(box_a, box_b) -> float:
        """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes (0 if the union is empty)."""
        a_x1, a_y1, a_x2, a_y2 = box_a
        b_x1, b_y1, b_x2, b_y2 = box_b
        inter_w = max(0, min(a_x2, b_x2) - max(a_x1, b_x1))
        inter_h = max(0, min(a_y2, b_y2) - max(a_y1, b_y1))
        inter = inter_w * inter_h
        union = (a_x2 - a_x1) * (a_y2 - a_y1) + (b_x2 - b_x1) * (b_y2 - b_y1) - inter
        return inter / union if union > 0 else 0

    @staticmethod
    def create_hetero_data_from_panel(panel: Dict) -> Optional[HeteroData]:
        """Convert one panel dict into a ``HeteroData`` graph.

        Keys read from ``panel``:
        - ``width``, ``height`` and ``panel_id``.
        - ``bubbles`` and ``faces``: lists of dicts with ``bbox`` = ``[x1, y1, x2, y2]``
          and ``bubble_id`` / ``face_id``. The coordinates are presumably pixels,
          since they are divided by width/height.
        - ``links``: dicts with ``bubble_id`` and ``face_id``.

        Node features are ``[cx, cy, w, h, area, aspect]``, normalized by panel size.
        Edge features are ``[dx, dy, dist, iou]``:
        - ``dx``/``dy`` is the normalized bubble-minus-face center offset.
        - ``dist`` is the length of that offset.
        - ``iou`` is computed on the raw boxes.

        ``edge_label`` is 1.0 for annotated links and 0.0 otherwise.

        Returns:
            The graph, or ``None`` when the panel has no bubbles, no faces, or a
            non-positive width/height.
        """
        bubbles = panel.get('bubbles', [])
        faces = panel.get('faces', [])
        links = panel.get('links', [])

        if not bubbles or not faces:
            return None

        W, H = panel['width'], panel['height']
        # A degenerate panel size would raise ZeroDivisionError during normalization.
        if W <= 0 or H <= 0:
            return None

        bubble_features = [DatasetLoader._box_features(b['bbox'], W, H) for b in bubbles]
        face_features = [DatasetLoader._box_features(f['bbox'], W, H) for f in faces]

        bubble_id_to_idx = {bubble['bubble_id']: i for i, bubble in enumerate(bubbles)}
        face_id_to_idx = {face['face_id']: i for i, face in enumerate(faces)}

        # Annotated (bubble_idx, face_idx) pairs; links to unknown ids are ignored.
        gt_links = {
            (bubble_id_to_idx[link['bubble_id']], face_id_to_idx[link['face_id']])
            for link in links
            if link['bubble_id'] in bubble_id_to_idx and link['face_id'] in face_id_to_idx
        }

        # Bubbles and faces are both non-empty here, so at least one edge is produced.
        edge_indices, edge_features, edge_labels = [], [], []
        for i, (bubble, b_feat) in enumerate(zip(bubbles, bubble_features)):
            for j, (face, f_feat) in enumerate(zip(faces, face_features)):
                # Elements 0 and 1 of the node features are the normalized centers.
                dx, dy = b_feat[0] - f_feat[0], b_feat[1] - f_feat[1]
                dist = (dx ** 2 + dy ** 2) ** 0.5
                iou = DatasetLoader._iou(bubble['bbox'], face['bbox'])

                edge_indices.append([i, j])
                edge_features.append([dx, dy, dist, iou])
                edge_labels.append(1.0 if (i, j) in gt_links else 0.0)

        data = HeteroData()
        data['bubble'].x = torch.tensor(bubble_features, dtype=torch.float)
        data['face'].x = torch.tensor(face_features, dtype=torch.float)
        edge_store = data['bubble', 'to', 'face']
        edge_store.edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_store.edge_attr = torch.tensor(edge_features, dtype=torch.float)
        edge_store.edge_label = torch.tensor(edge_labels, dtype=torch.float)

        data.panel_id = panel['panel_id']
        data.width = W
        data.height = H

        return data
|
|
|
|
|
class AssocGCN(nn.Module):
    """Edge-scoring network for speech-bubble -> speaker (face) association.

    The forward pass has three stages:
    1. Bubble and face nodes are embedded with a shared MLP (``node_encoder``).
    2. Three rounds of edge-conditioned message passing refine the embeddings.
       Each round computes one message per bubble->face edge from
       ``[bubble_h, face_h, edge_attr]``. Both endpoints receive that message,
       mean-aggregated over their degree and added residually.
    3. ``edge_mlp`` scores every edge.

    Args:
        in_feats: Size of the per-node input features.
        hid: Hidden width of all MLPs.
        edge_feats: Size of the per-edge attribute vector (default 4).
    """

    def __init__(self, in_feats: int = 6, hid: int = 128, edge_feats: int = 4):
        super().__init__()
        self.node_encoder = nn.Sequential(
            nn.Linear(in_feats, hid),
            nn.ReLU(),
            nn.Linear(hid, hid),
        )

        # Message MLPs are created in the same order as before (conv1..3, then
        # edge_mlp) so parameter names and RNG-driven initialization are unchanged.
        self.conv1 = self._message_mlp(hid, edge_feats)
        self.conv2 = self._message_mlp(hid, edge_feats)
        self.conv3 = self._message_mlp(hid, edge_feats)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hid + edge_feats, hid),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hid, 1),
        )

    @staticmethod
    def _message_mlp(hid: int, edge_feats: int) -> nn.Sequential:
        """MLP mapping ``[bubble_h, face_h, edge_attr]`` to a ``hid``-dim edge message."""
        return nn.Sequential(
            nn.Linear(hid * 2 + edge_feats, hid),
            nn.ReLU(),
            nn.Linear(hid, hid),
        )

    def forward(self, data):
        """Return one unnormalized logit per ``('bubble', 'to', 'face')`` edge.

        Args:
            data: HeteroData/Batch (or any object indexable the same way) that
                provides ``data['bubble'].x`` and ``data['face'].x``, plus
                ``edge_index`` and ``edge_attr`` on ``data['bubble', 'to', 'face']``.

        Returns:
            1-D tensor of logits aligned with the columns of ``edge_index``.
        """
        bubble_x = self.node_encoder(data['bubble'].x)
        face_x = self.node_encoder(data['face'].x)

        edge_store = data['bubble', 'to', 'face']
        edge_index = edge_store.edge_index
        edge_attr = edge_store.edge_attr
        src_idx, dst_idx = edge_index[0], edge_index[1]

        # Degrees do not change between layers; clamp so isolated nodes divide by 1.
        bubble_deg = (torch.bincount(src_idx, minlength=bubble_x.size(0))
                      .float().clamp(min=1).unsqueeze(1))
        face_deg = (torch.bincount(dst_idx, minlength=face_x.size(0))
                    .float().clamp(min=1).unsqueeze(1))

        for conv in (self.conv1, self.conv2, self.conv3):
            edge_input = torch.cat([bubble_x[src_idx], face_x[dst_idx], edge_attr], dim=1)
            messages = conv(edge_input)

            # Vectorized scatter-sum of the per-edge messages onto both endpoints,
            # replacing a per-edge Python loop with .item() calls.
            bubble_agg = torch.zeros_like(bubble_x).index_add(0, src_idx, messages)
            face_agg = torch.zeros_like(face_x).index_add(0, dst_idx, messages)

            # Mean aggregation followed by a residual update.
            bubble_x = bubble_x + bubble_agg / bubble_deg
            face_x = face_x + face_agg / face_deg

        edge_input = torch.cat([bubble_x[src_idx], face_x[dst_idx], edge_attr], dim=1)
        return self.edge_mlp(edge_input).squeeze(-1)
|
|
|
|
|
def hungarian_matching(scores: torch.Tensor, src_indices, dst_indices, min_score: float = 0.0):
    """Decode edge logits into a one-to-one bubble -> face assignment.

    Builds a ``(num_bubbles, num_faces)`` cost matrix. Each candidate edge holds
    ``-sigmoid(score)``; pairs with no edge hold a prohibitive cost of 1e6. The
    matrix is solved with ``scipy.optimize.linear_sum_assignment``. Because the
    assignment is one-to-one, each face is matched to at most one bubble. With
    more bubbles than faces, some bubbles stay unassigned.

    Args:
        scores: 1-D tensor of edge logits aligned with ``src_indices``/``dst_indices``.
        src_indices: Integer tensor with the bubble index of each edge.
        dst_indices: Integer tensor with the face index of each edge.
        min_score: Keep a match only if its sigmoid probability is strictly
            greater than this. The default 0.0 keeps every matched edge.

    Returns:
        Dict ``{bubble_idx: face_idx}``. Bubbles without a kept match are omitted.
    """
    if scores.numel() == 0:
        return {}

    src_np = src_indices.detach().cpu().numpy().astype(np.int64)
    dst_np = dst_indices.detach().cpu().numpy().astype(np.int64)
    probs = scores.detach().cpu().sigmoid().numpy()

    num_bubbles = int(src_np.max()) + 1
    num_faces = int(dst_np.max()) + 1

    # Non-edges get a prohibitive cost; the mask guarantees they are never returned.
    cost_matrix = np.full((num_bubbles, num_faces), 1e6, dtype=np.float32)
    is_edge = np.zeros((num_bubbles, num_faces), dtype=bool)
    cost_matrix[src_np, dst_np] = -probs
    is_edge[src_np, dst_np] = True

    row_indices, col_indices = linear_sum_assignment(cost_matrix)

    return {
        int(r): int(c)
        for r, c in zip(row_indices, col_indices)
        if is_edge[r, c] and -cost_matrix[r, c] > min_score
    }
|
|
|
|
|
def train_gcn(dataset: List[HeteroData], epochs: int = 200, batch_size: int = 16, lr: float = 1e-4):
    """Train an ``AssocGCN`` on ``dataset`` and return the model after the last epoch.

    Training setup:
    - Optimizer is AdamW with learning rate ``lr``.
    - Loss is ``BCEWithLogitsLoss``.
    - ``pos_weight`` is the dataset-wide negative/positive edge ratio, or 9.0 if
      there are no positive edges.

    Each epoch reshuffles the panels and trains in mini-batches of ``batch_size``.
    It then prints the loss, plus accuracy, precision, recall and F1 at a 0.5
    probability threshold. A checkpoint is saved whenever the epoch's average
    loss improves on the best seen so far.

    Args:
        dataset: Graphs produced by ``DatasetLoader``.
        epochs: Number of passes over the dataset.
        batch_size: Number of panels per mini-batch.
        lr: AdamW learning rate.

    Returns:
        The trained model (weights from the final epoch, not necessarily the
        best checkpoint).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if not dataset:
        raise ValueError("Dataset is empty!")

    print(f"Training on {len(dataset)} panels...")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AssocGCN().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Every bubble/face pair is an edge, so negatives dominate; re-weight positives.
    total_positive = sum(data['bubble', 'to', 'face'].edge_label.sum().item() for data in dataset)
    total_edges = sum(len(data['bubble', 'to', 'face'].edge_label) for data in dataset)
    pos_weight = (total_edges - total_positive) / total_positive if total_positive > 0 else 9.0

    print(f"Positive edges: {total_positive}/{total_edges} ({100*total_positive/total_edges:.1f}%)")
    print(f"Using pos_weight: {pos_weight:.2f}")

    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    best_loss = float("inf")
    for epoch in range(epochs):
        model.train()
        shuffled_dataset = dataset.copy()
        random.shuffle(shuffled_dataset)

        total_loss = 0.0
        total_correct = total_samples = 0
        total_tp = total_fp = total_fn = 0

        for start_idx in range(0, len(shuffled_dataset), batch_size):
            batch_data = shuffled_dataset[start_idx:start_idx + batch_size]
            batch = Batch.from_data_list(batch_data).to(device)

            logits = model(batch)
            labels = batch['bubble', 'to', 'face'].edge_label
            loss = loss_fn(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Loss is a per-edge mean; weight it by panel count for the epoch average.
            total_loss += loss.item() * len(batch_data)

            # Metrics on this batch's (pre-step) predictions; no autograd graph needed.
            with torch.no_grad():
                pred_pos = torch.sigmoid(logits) > 0.5
                true_pos = labels > 0.5
                total_correct += (pred_pos == true_pos).sum().item()
                total_samples += labels.numel()
                total_tp += (pred_pos & true_pos).sum().item()
                total_fp += (pred_pos & ~true_pos).sum().item()
                total_fn += (~pred_pos & true_pos).sum().item()

        avg_loss = total_loss / len(shuffled_dataset)
        accuracy = total_correct / total_samples if total_samples > 0 else 0.0
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(model, epoch + 1, best_loss)

        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        print(f"Epoch {epoch+1:02d}/{epochs}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}, "
              f"Precision = {precision:.4f}, Recall = {recall:.4f}, F1 = {f1:.4f}")

    print("Training completed!")
    return model
|
|
|
|
|
def infer_associations(model, data):
    """Predict a one-to-one bubble -> face assignment for one graph.

    Moves ``data`` to the model's device and runs ``model`` without gradients in
    eval mode. The edge logits are decoded with ``hungarian_matching``. The
    model's previous train/eval mode is restored afterwards, so calling this
    mid-training does not silently leave dropout disabled.

    Args:
        model: ``AssocGCN`` (or compatible module returning one logit per edge).
        data: Graph with a ``('bubble', 'to', 'face')`` edge set.

    Returns:
        Dict ``{bubble_idx: face_idx}`` as produced by ``hungarian_matching``.
    """
    device = next(model.parameters()).device
    data = data.to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(data)
    finally:
        model.train(was_training)

    src, dst = data['bubble', 'to', 'face'].edge_index
    return hungarian_matching(logits, src, dst)
|
|
|
|
|
|
|
def train_speaker(config):
    """Train the speaker-association GCN on every panel file in ``<root>/panel_data``.

    Each regular file in that directory is loaded, in sorted order, with
    ``DatasetLoader.load_converted_dataset``. Files that fail to load are
    reported and skipped. All loaded panels are then passed to ``train_gcn``.

    Args:
        config: Mapping with a ``"root"`` directory. The optional keys
            ``"epochs"`` (default 30) and ``"batch_size"`` (default 16)
            override the training settings.

    Returns:
        The model returned by ``train_gcn``.

    Raises:
        OSError: if ``<root>/panel_data`` cannot be listed.
        ValueError: from ``train_gcn`` if no panels could be loaded.
    """
    # os.path.join inserts the separator even when "root" lacks a trailing slash.
    panel_dir = os.path.join(config["root"], "panel_data")

    dataset = []
    for panel_data_file in sorted(os.listdir(panel_dir)):
        panel_path = os.path.join(panel_dir, panel_data_file)
        if not os.path.isfile(panel_path):
            continue
        print(panel_data_file)
        try:
            dataset += DatasetLoader.load_converted_dataset(panel_path)
        except FileNotFoundError:
            print(f"Error: {panel_path} not found!")
            print("Please ensure your converted dataset file exists.")
        except Exception as e:
            # Best-effort loading: a malformed file is reported and skipped.
            print(f"Error loading {panel_path}: {e}")
            print("Please check your dataset format and file paths.")

    return train_gcn(
        dataset,
        epochs=config.get("epochs", 30),
        batch_size=config.get("batch_size", 16),
    )