Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import re | |
| import sys | |
| from collections import Counter | |
| from dataclasses import dataclass | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import LabelEncoder | |
| from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler | |
| from torchvision import transforms | |
# Repository root: the parent of the directory that contains this script.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Model definitions are shared with the web app backend; expose that directory
# on the import path (only once) so ``models`` resolves without packaging it.
BACKEND_DIR = os.path.join(ROOT_DIR, "web_app_pro", "backend")
if BACKEND_DIR not in sys.path:
    sys.path.append(BACKEND_DIR)
from models import get_model  # noqa: E402
@dataclass
class TrainConfig:
    """Paths and hyper-parameters for one retraining run.

    ``main`` builds this from command-line flags using keyword arguments, which
    is why the class must be a dataclass: a plain class with only annotations
    has no ``__init__`` accepting these fields.
    """

    data_csv: str  # CSV with at least ``filename`` and ``label`` columns
    image_dir: str  # directory that ``filename`` values are resolved against
    output_dir: str  # where weights, vocab, label encoder and insights are written
    epochs: int  # number of passes over the training split
    batch_size: int  # mini-batch size used by every DataLoader
    lr: float  # AdamW learning rate
    max_len: int  # token ids per report after truncation / padding
    min_freq: int  # minimum corpus count for a token to enter the vocabulary
    seed: int  # seed for RNGs and the stratified splits
    use_weighted_sampler: bool  # draw training rows inversely to class frequency
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    When CUDA is available, every visible GPU's generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def clean_text(text: str) -> str:
    """Normalise free-text report content for whitespace tokenisation.

    Steps, in order:
      1. Falsy input (``None``, ``""``, ``0``) becomes ``""``; anything else is
         passed through ``str`` and lower-cased.
      2. Standalone ``x`` / ``xx`` / ``xxxx`` tokens are blanked out.
      3. Every character other than ``a-z``, ``0-9`` and whitespace becomes a space.
      4. Whitespace runs collapse to one space and the ends are trimmed.

    NOTE(review): step 2's single-``x`` alternative also matches the ``x`` in
    hyphenated words, so "x-ray" becomes "ray"; ``xxx`` (three x's) is kept.
    Confirm this is intended before changing it, since the cleaned text must
    match whatever preprocessing is used at inference time.
    """
    normalized = str(text or "").lower()
    substitutions = (
        (r"\b(xxxx|xx|x)\b", " "),
        (r"[^a-z0-9\s]", " "),
        (r"\s+", " "),
    )
    for pattern, replacement in substitutions:
        normalized = re.sub(pattern, replacement, normalized)
    return normalized.strip()
def build_text_field(df: pd.DataFrame) -> pd.Series:
    """Merge the report text columns of ``df`` into one cleaned string per row.

    Uses whichever of ``indication``, ``findings`` and ``impression`` are present,
    in that order. Missing cells count as empty strings, and non-string cells are
    converted with ``str`` before joining. Each merged string is passed through
    ``clean_text``.

    Returns:
        A Series aligned to ``df.index``. When none of the text columns exist,
        every value is ``""``.
    """
    cols = [c for c in ("indication", "findings", "impression") if c in df.columns]
    if not cols:
        # Keep df's index so ``df[...] = build_text_field(df)`` aligns row-for-row
        # instead of producing NaNs for a filtered or re-indexed frame.
        return pd.Series("", index=df.index, dtype=object)
    # astype(str) guards against numeric cells, which " ".join would reject.
    merged = df[cols].fillna("").astype(str).agg(" ".join, axis=1)
    return merged.map(clean_text)
def build_vocab(texts, min_freq: int = 2):
    """Build a token -> id mapping from whitespace-tokenised ``texts``.

    Ids 0 and 1 are reserved for ``<pad>`` and ``<unk>``. Every other token that
    occurs at least ``min_freq`` times across all texts gets the next free id,
    in order of first appearance.

    Args:
        texts: Iterable of strings; each one is split on whitespace.
        min_freq: Minimum total count for a token to be included.

    Returns:
        dict mapping each token to a unique, contiguous integer id.
    """
    counter = Counter()
    for text in texts:
        counter.update(text.split())
    vocab = {"<pad>": 0, "<unk>": 1}
    for token, freq in counter.items():
        # A literal "<pad>"/"<unk>" in the text must not be re-added: that would
        # overwrite the reserved id and give the next token a duplicate id.
        if freq >= min_freq and token not in vocab:
            vocab[token] = len(vocab)
    return vocab
def encode_text(text: str, vocab: dict, max_len: int):
    """Map the whitespace tokens of ``text`` to vocabulary ids of fixed length.

    Tokens missing from ``vocab`` map to ``vocab["<unk>"]``. The id list is cut
    to ``max_len`` and right-padded with ``vocab["<pad>"]`` up to ``max_len``.
    """
    ids = []
    for token in text.split():
        ids.append(vocab.get(token, vocab["<unk>"]))
    ids = ids[:max_len]
    padding = max_len - len(ids)
    return ids + [vocab["<pad>"]] * padding
class FusionDataset(Dataset):
    """Dataset of ``(image, token_ids, label_id)`` triples for the fusion model.

    Each row of ``df`` supplies three things:
      - ``filename``: resolved against ``image_dir``; the file is opened, converted
        to RGB and passed through ``transform``.
      - ``text_input``: encoded with ``encode_text(..., vocab, max_len)``.
      - ``label``: encoded with ``label_encoder``.
    """

    def __init__(self, df, image_dir, vocab, label_encoder, transform, max_len):
        # Positional lookups in __getitem__ rely on a 0..n-1 index.
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.vocab = vocab
        self.label_encoder = label_encoder
        self.transform = transform
        self.max_len = max_len
        # Encode every label once rather than calling the encoder per item;
        # labels the encoder has not seen are also reported here, up front.
        self.label_ids = [int(i) for i in label_encoder.transform(self.df["label"])]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row["filename"])
        # The context manager closes the file deterministically, including when
        # decoding fails or the image has several frames.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)
        text_ids = encode_text(row["text_input"], self.vocab, self.max_len)
        text_tensor = torch.tensor(text_ids, dtype=torch.long)
        label_tensor = torch.tensor(self.label_ids[idx], dtype=torch.long)
        return img, text_tensor, label_tensor
def make_dataloaders(df, cfg: TrainConfig, vocab, label_encoder):
    """Split ``df`` into stratified train/val/test sets and wrap each in a DataLoader.

    Splitting:
      - 30% of rows are held out; the held-out part is halved into validation
        and test, giving roughly 70/15/15.
      - Both splits are stratified on ``label`` and use ``cfg.seed``.

    Preprocessing is the same for every split: resize to 224x224, convert to a
    tensor, and normalise with fixed per-channel mean/std. These are the
    commonly used ImageNet statistics, presumably matching the image backbone
    returned by ``get_model`` (confirm).

    Ordering:
      - With ``cfg.use_weighted_sampler``, training rows are drawn with
        replacement, each weighted ``1 / count(label)``.
      - Otherwise the training set is shuffled.
      - Validation and test loaders keep row order.

    Returns:
        ``(train_loader, val_loader, test_loader, train_df)``.
    """
    train_df, held_out_df = train_test_split(
        df, test_size=0.30, random_state=cfg.seed, stratify=df["label"]
    )
    val_df, test_df = train_test_split(
        held_out_df, test_size=0.50, random_state=cfg.seed, stratify=held_out_df["label"]
    )

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_ds, val_ds, test_ds = [
        FusionDataset(split, cfg.image_dir, vocab, label_encoder, preprocess, cfg.max_len)
        for split in (train_df, val_df, test_df)
    ]

    if cfg.use_weighted_sampler:
        label_freq = train_df["label"].value_counts()
        row_weights = train_df["label"].map(lambda lbl: 1.0 / label_freq[lbl]).values
        order_kwargs = {
            "sampler": WeightedRandomSampler(
                weights=torch.tensor(row_weights, dtype=torch.double),
                num_samples=len(row_weights),
                replacement=True,
            )
        }
    else:
        order_kwargs = {"shuffle": True}
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, **order_kwargs)

    val_loader, test_loader = [
        DataLoader(ds, batch_size=cfg.batch_size, shuffle=False) for ds in (val_ds, test_ds)
    ]
    return train_loader, val_loader, test_loader, train_df
def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader`` without tracking gradients.

    ``loader`` yields ``(images, token_ids, labels)`` batches. Predictions are
    the argmax over the logits of ``model(images, token_ids)``.

    Returns a dict with:
      - ``accuracy``;
      - support-weighted ``precision``, ``recall`` and ``f1``, where ill-defined
        per-class scores (zero denominators) count as 0 without a warning;
      - ``samples``, the number of rows scored.

    The model is left in eval mode.
    """
    model.eval()
    targets, predictions = [], []
    with torch.no_grad():
        for images, token_ids, labels in loader:
            logits = model(images.to(device), token_ids.to(device))
            predictions.extend(torch.argmax(logits, dim=1).cpu().tolist())
            # Labels are only compared on the host, so they never visit the device.
            targets.extend(labels.tolist())
    weighted = {"average": "weighted", "zero_division": 0}
    return {
        "accuracy": float(accuracy_score(targets, predictions)),
        "precision": float(precision_score(targets, predictions, **weighted)),
        "recall": float(recall_score(targets, predictions, **weighted)),
        "f1": float(f1_score(targets, predictions, **weighted)),
        "samples": len(targets),
    }
def _load_training_frame(cfg: TrainConfig) -> pd.DataFrame:
    """Load ``cfg.data_csv`` and keep only rows that can be trained on.

    Adds a cleaned ``text_input`` column and drops rows lacking a filename or a
    label, or whose image is absent from ``cfg.image_dir``. Labels are cast to
    ``str`` so the dataset, sampler and loss weights all key on the same values
    the label encoder is fitted on.

    Raises:
        ValueError: if required columns are missing or no usable rows remain.
    """
    df = pd.read_csv(cfg.data_csv)
    expected = {"filename", "label"}
    if not expected.issubset(set(df.columns)):
        raise ValueError(f"{cfg.data_csv} must contain columns: {sorted(expected)}")
    df = df.copy()
    df["text_input"] = build_text_field(df)
    df = df[df["filename"].notna() & df["label"].notna()].copy()
    img_exists = df["filename"].map(
        lambda f: os.path.exists(os.path.join(cfg.image_dir, f))
    ).astype(bool)
    missing = int((~img_exists).sum())
    if missing > 0:
        print(f"Dropping {missing} rows with missing image files.")
    df = df[img_exists].copy()
    if df.empty:
        raise ValueError(
            f"No usable rows in {cfg.data_csv}: every row lacks a filename, a label, "
            f"or an image file under {cfg.image_dir}."
        )
    # The encoder is fitted on strings. Without this cast, numeric labels
    # (e.g. 0/1) are rejected by label_encoder.transform and miss every
    # class-weight lookup, silently falling back to weight 1.0.
    df["label"] = df["label"].astype(str)
    return df


def _select_device() -> str:
    """Return "cuda" if available, else "mps" if available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _class_weights(train_df, label_encoder, device):
    """Return inverse-frequency loss weights ordered like ``label_encoder.classes_``."""
    train_counts = train_df["label"].value_counts()
    values = [1.0 / float(train_counts.get(cls, 1.0)) for cls in label_encoder.classes_]
    return torch.tensor(values, dtype=torch.float32, device=device)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen in the pass.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for imgs, txt, y in loader:
        imgs = imgs.to(device)
        txt = txt.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(imgs, txt)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * y.size(0)
        preds = torch.argmax(logits, dim=1)
        correct += int((preds == y).sum().item())
        total += int(y.size(0))
    return running_loss / max(total, 1), correct / max(total, 1)


def _save_artifacts(cfg, model, vocab, label_encoder, test_metrics, epoch_loss, epoch_acc, weighted_loss):
    """Write weights, vocab, label encoder and insights JSON to ``cfg.output_dir``.

    Returns:
        ``(model_path, vocab_path, encoder_path, insights_path)``.
    """
    os.makedirs(cfg.output_dir, exist_ok=True)
    model_path = os.path.join(cfg.output_dir, "medisim_diagnostic_model_retrained.pth")
    vocab_path = os.path.join(cfg.output_dir, "vocab_retrained.pth")
    encoder_path = os.path.join(cfg.output_dir, "label_encoder_retrained.pth")
    insights_path = os.path.join(cfg.output_dir, "retrained_fusion_insights.json")
    torch.save(model.state_dict(), model_path)
    torch.save(vocab, vocab_path)
    # NOTE(review): this pickles a full sklearn object; recent PyTorch defaults
    # torch.load to weights_only=True, so the loader may need weights_only=False.
    torch.save(label_encoder, encoder_path)
    insights = {
        "retrained_multimodal_fusion": {
            "summary": test_metrics,
            "epoch_loss": epoch_loss,
            "epoch_accuracy": epoch_acc,
            "class_names": list(label_encoder.classes_),
            "config": {
                "epochs": cfg.epochs,
                "batch_size": cfg.batch_size,
                "lr": cfg.lr,
                "max_len": cfg.max_len,
                "min_freq": cfg.min_freq,
                "seed": cfg.seed,
                "weighted_sampler": cfg.use_weighted_sampler,
                "weighted_loss": weighted_loss,
            },
        }
    }
    with open(insights_path, "w", encoding="utf-8") as f:
        json.dump(insights, f, indent=2)
    return model_path, vocab_path, encoder_path, insights_path


def train(cfg: TrainConfig):
    """Retrain the fusion model described by ``cfg`` and write its artifacts.

    Steps:
      1. Seed the RNGs, then load and filter the CSV.
      2. Build the vocabulary and label encoder.
      3. Split about 70/15/15 via ``make_dataloaders``.
      4. Train for ``cfg.epochs``, keeping the weights with the best validation
         weighted-F1.
      5. Evaluate on the test split and save the artifacts.

    Raises:
        ValueError: if the CSV lacks required columns or no usable rows remain.
    """
    set_seed(cfg.seed)
    df = _load_training_frame(cfg)

    vocab = build_vocab(df["text_input"].tolist(), min_freq=cfg.min_freq)
    label_encoder = LabelEncoder()
    label_encoder.fit(df["label"])

    device = _select_device()
    model = get_model(vocab_size=len(vocab), num_classes=len(label_encoder.classes_), device=device)
    train_loader, val_loader, test_loader, train_df = make_dataloaders(df, cfg, vocab, label_encoder)

    # Balance classes exactly once. The weighted sampler already equalises
    # class frequency per batch; stacking inverse-frequency loss weights on top
    # would over-correct toward minority classes.
    use_weighted_loss = not cfg.use_weighted_sampler
    class_weights = _class_weights(train_df, label_encoder, device) if use_weighted_loss else None
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    best_val_f1 = -1.0
    best_state = None
    epoch_loss, epoch_acc = [], []
    for epoch in range(cfg.epochs):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = evaluate(model, val_loader, device)
        epoch_loss.append(train_loss)
        epoch_acc.append(train_acc)
        print(
            f"Epoch {epoch + 1}/{cfg.epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
            f"val_f1={val_metrics['f1']:.4f}"
        )
        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            # state_dict() tensors share storage with the live parameters, and
            # .cpu() is a no-op on CPU, so clone to freeze this snapshot.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
    test_metrics = evaluate(model, test_loader, device)

    model_path, vocab_path, encoder_path, insights_path = _save_artifacts(
        cfg, model, vocab, label_encoder, test_metrics, epoch_loss, epoch_acc, use_weighted_loss
    )
    print("\nRetraining complete.")
    print(f"Model: {model_path}")
    print(f"Vocab: {vocab_path}")
    print(f"Label encoder: {encoder_path}")
    print(f"Insights: {insights_path}")
    print(f"Test metrics: {test_metrics}")
def parse_args(argv=None):
    """Parse command-line options for a retraining run.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace whose attribute names mirror the ``TrainConfig`` fields,
        except ``no_weighted_sampler``, which is the inverse of
        ``use_weighted_sampler``.
    """
    parser = argparse.ArgumentParser(
        description="Retrain MediSim diagnostic model with cleaned text and class balancing.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--data-csv",
        default=os.path.join(ROOT_DIR, "data", "processed_metadata.csv"),
        help="CSV with 'filename' and 'label' columns, plus optional 'indication', 'findings', 'impression' text columns.",
    )
    parser.add_argument(
        "--image-dir",
        default=os.path.join(ROOT_DIR, "data", "images", "images_normalized"),
        help="Directory that the CSV 'filename' values are resolved against.",
    )
    parser.add_argument(
        "--output-dir",
        default=os.path.join(ROOT_DIR, "data"),
        help="Directory for the retrained weights, vocab, label encoder and insights JSON.",
    )
    parser.add_argument("--epochs", type=int, default=8, help="Number of training epochs.")
    parser.add_argument("--batch-size", type=int, default=16, help="Mini-batch size for all data loaders.")
    parser.add_argument("--lr", type=float, default=3e-4, help="AdamW learning rate.")
    parser.add_argument("--max-len", type=int, default=60, help="Token ids per report after truncation/padding.")
    parser.add_argument("--min-freq", type=int, default=2, help="Minimum token frequency for inclusion in the vocabulary.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for RNGs and the data splits.")
    parser.add_argument(
        "--no-weighted-sampler",
        action="store_true",
        help="Shuffle training rows uniformly instead of sampling them inversely to class frequency.",
    )
    return parser.parse_args(argv)
def main():
    """Command-line entry point: turn the parsed flags into a TrainConfig and train."""
    cli = parse_args()
    train(
        TrainConfig(
            data_csv=cli.data_csv,
            image_dir=cli.image_dir,
            output_dir=cli.output_dir,
            epochs=cli.epochs,
            batch_size=cli.batch_size,
            lr=cli.lr,
            max_len=cli.max_len,
            min_freq=cli.min_freq,
            seed=cli.seed,
            # The flag is phrased negatively: the sampler is on unless disabled.
            use_weighted_sampler=not cli.no_weighted_sampler,
        )
    )


if __name__ == "__main__":
    main()