RCMTUNetV4-VLM — Brain Tumor Segmentation & Report Generation

Description

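Multimodal pipeline for brain tumor segmentation and automated neuro-oncology report generation. It combines three components: the RCMTUNetV4 segmentation model, a retrieval-augmented generation (RAG) module that indexes 40 text chunks from the WHO CNS 2021 classification with FAISS, and LLaVA-Med for report generation.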
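To make the hand-off between the segmentation and retrieval stages concrete, here is a minimal sketch that turns a segmentation label map into per-subregion volumes and a plain-text query for the retriever. It assumes the 4 output channels follow a BraTS-style convention (0 = background, 1 = necrotic core, 2 = edema, 3 = enhancing tumor; BraTS-style data typically store enhancing tumor as label 4, so a remap to 3 is assumed) and 1 mm³ voxels by default. `subregion_volumes_ml` and `findings_to_query` are hypothetical helpers, not functions from `pipeline.py`.

```python
import numpy as np

# Assumed label convention (BraTS-style, enhancing tumor remapped from 4 to 3).
LABELS = {1: "necrotic core", 2: "peritumoral edema", 3: "enhancing tumor"}

def subregion_volumes_ml(label_map: np.ndarray, voxel_volume_mm3: float = 1.0) -> dict:
    """Volume in millilitres of each tumor subregion in an integer label map."""
    volumes = {}
    for label, name in LABELS.items():
        n_voxels = int(np.count_nonzero(label_map == label))
        volumes[name] = n_voxels * voxel_volume_mm3 / 1000.0  # 1 mL = 1000 mm^3
    return volumes

def findings_to_query(volumes: dict) -> str:
    """Hypothetical query builder: list the subregions that are present."""
    present = [f"{name} {vol:.1f} mL" for name, vol in volumes.items() if vol > 0]
    if not present:
        return "no tumor detected"
    return "glioma with " + ", ".join(present)

# Toy 3D label map standing in for seg_model(...).argmax(dim=1)
label_map = np.zeros((40, 40, 40), dtype=np.int64)
label_map[0:10, 0:10, 0:20] = 2   # 2000 voxels of edema
label_map[10:15, 0:10, 0:10] = 3  # 500 voxels of enhancing tumor
volumes = subregion_volumes_ml(label_map)
print(findings_to_query(volumes))  # glioma with peritumoral edema 2.0 mL, enhancing tumor 0.5 mL
```

The resulting string is the kind of input a retrieval function such as `rag_retrieve` (see Usage) could receive; the actual query format used by the authors is not documented here.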

📂 Files in this repository

| File | Description |
|---|---|
| `rcmt_unet_v4_final.pth` | RCMTUNetV4 segmentation weights |
| `pipeline.py` | Full architecture (all classes) |
| `rag_who_chunks.json` | 40 WHO CNS 2021 RAG chunks |
| `rag_faiss.index` | Precomputed FAISS index |
| `rag_embeddings.npy` | NumPy embeddings (backup) |
| `prompts.json` | All v3.2 prompts + ACP profiles |
| `config.json` | Configuration and metrics |
| `evaluation_metrics.json` | Detailed results |
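`rag_embeddings.npy` is listed as a backup for the FAISS index. A minimal sketch of searching it without FAISS, assuming the rows are L2-normalized MiniLM embeddings stored in the same order as the chunks in `rag_who_chunks.json`. Under that assumption an exact inner-product search (what `faiss.IndexFlatIP` computes) equals cosine similarity. `numpy_search` is a hypothetical helper.

```python
import numpy as np

def numpy_search(embeddings: np.ndarray, query, top_k: int = 4):
    """Exact inner-product search returning (scores, indices) with FAISS-like shape (1, k)."""
    query = np.asarray(query, dtype=np.float32).reshape(1, -1)
    scores = embeddings.astype(np.float32) @ query[0]    # one score per chunk
    top_k = min(top_k, scores.shape[0])                  # never ask for more rows than exist
    order = np.argsort(-scores, kind="stable")[:top_k]   # highest score first
    return scores[order][None, :], order[None, :]

# Usage with the repository files (assumes local_dir and embedder from the Usage section):
# emb = np.load(f"{local_dir}/rag_embeddings.npy")
# q = embedder.encode(["IDH-mutant astrocytoma grading"], normalize_embeddings=True)
# D, I = numpy_search(emb, q[0], top_k=4)

toy = np.eye(3, dtype=np.float32)  # three orthonormal "chunk" embeddings
D, I = numpy_search(toy, [0.0, 0.6, 0.8], top_k=2)
print(I.tolist())  # [[2, 1]]
```

Brute force is practical here because the knowledge base holds only 40 chunks; the returned `D, I` can be filtered exactly like the FAISS output.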

Performance (n=50, seed=42, UCSF-PDGM)

| Metric | Score | vs. baseline |
|---|---|---|
| BERTScore-F | 0.814 | > MediVLM (0.616) ✅ |
| TBFact | 0.922 | > BTReport (0.353) ✅ |
| RadGraph-F1 | 0.871 | > AutoRG (0.380) ✅ |
| Anti-hallucination | 1.000 | Unique ✅ |
| Cross-validation | 1.000 | Unique ✅ |
| Global score | 0.852 | Class A ✅ |

🚀 Usage: loading the complete pipeline

```python
import json
import sys

import faiss
import numpy as np
import torch
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration, LlavaNextProcessor

# 1. Download all repository files
local_dir = snapshot_download(repo_id="mayoula/rcmt-unet-v4-vlm")

# 2. Load the architecture (pipeline.py ships with the repository)
sys.path.insert(0, local_dir)
from pipeline import RCMTUNetV4

# 3. Load the segmentation weights
seg_model = RCMTUNetV4(in_channels=4, out_channels=4, features=(24, 48, 96, 192))
seg_model.load_state_dict(torch.load(f"{local_dir}/rcmt_unet_v4_final.pth", map_location="cpu"))
seg_model.eval()

# 4. Load the RAG knowledge base (queries must use the same embedder that built the index)
with open(f"{local_dir}/rag_who_chunks.json") as f:
    rag_data = json.load(f)
WHO_CHUNKS = rag_data["chunks"]
faiss_idx = faiss.read_index(f"{local_dir}/rag_faiss.index")
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# 5. Load the prompts
with open(f"{local_dir}/prompts.json") as f:
    prompts = json.load(f)

# 6. Load the VLM (LLaVA-Med, not fine-tuned, downloaded from the Hugging Face Hub)
# 4-bit NF4 quantization requires the bitsandbytes package and a CUDA GPU.
# Note: microsoft/llava-med-v1.5-mistral-7b is published in the original LLaVA format;
# if the LlavaNext* classes fail to load it, a transformers-format conversion may be needed.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
vlm_processor = LlavaNextProcessor.from_pretrained("microsoft/llava-med-v1.5-mistral-7b")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(
    "microsoft/llava-med-v1.5-mistral-7b", quantization_config=bnb, device_map="auto"
)

# 7. RAG retrieval
def rag_retrieve(query, top_k=4, min_score=0.10):
    """Return the top_k WHO chunks for `query` as numbered [REF-n] lines.

    Assumes rag_faiss.index is an inner-product index over normalized embeddings,
    so higher scores mean more similar (with an L2 index the threshold would be inverted).
    """
    emb = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    D, I = faiss_idx.search(emb, top_k)
    # FAISS pads missing results with index -1; drop those and low-similarity hits
    hits = [i for d, i in zip(D[0], I[0]) if i >= 0 and d > min_score]
    # Number the references after filtering so they stay consecutive
    refs = [f"[REF-{r + 1}] {WHO_CHUNKS[i]}" for r, i in enumerate(hits)]
    return "\n".join(refs) if refs else "Standard glioma protocol (WHO CNS 2021)."
```
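The loading code does not show how MRI volumes are turned into the 4-channel input expected by `RCMTUNetV4(in_channels=4, ...)`. A minimal preprocessing sketch follows, assuming one channel per modality (T1, T1c, T2, FLAIR, in an order that should be checked against `pipeline.py` or `config.json`) and per-channel z-score normalization over non-zero (brain) voxels, a common choice for BraTS-style data but not confirmed for this model. The card does not say whether the model takes 2D slices or 3D volumes; the sketch works for either array rank. `MODALITIES`, `zscore_nonzero` and `stack_modalities` are hypothetical names.

```python
import numpy as np

# Assumed channel order; verify against the training configuration.
MODALITIES = ("t1", "t1c", "t2", "flair")

def zscore_nonzero(img: np.ndarray) -> np.ndarray:
    """Z-score normalize using only non-zero (brain) voxels; background stays 0."""
    img = img.astype(np.float32)
    mask = img != 0
    if not mask.any():
        return img
    mean, std = img[mask].mean(), img[mask].std()
    out = np.zeros_like(img)
    out[mask] = (img[mask] - mean) / (std + 1e-8)  # epsilon guards against std == 0
    return out

def stack_modalities(volumes: dict) -> np.ndarray:
    """Stack the four normalized modalities into a (1, 4, *spatial) float32 batch."""
    channels = [zscore_nonzero(volumes[name]) for name in MODALITIES]
    return np.stack(channels, axis=0)[None, ...]

rng = np.random.default_rng(42)
vols = {name: rng.uniform(1, 100, size=(8, 8, 8)) for name in MODALITIES}
x = stack_modalities(vols)
print(x.shape)  # (1, 4, 8, 8, 8)

# With the real model from the loading code (not run here):
# with torch.no_grad():
#     logits = seg_model(torch.from_numpy(x))
# label_map = logits.argmax(dim=1)[0].numpy()
```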

Dataset

Trained and evaluated on UCSF-PDGM (n=50 patients, seed=42).
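The card does not say how the 50 patients were selected with `seed=42`. If you want a comparable fixed-seed subset, one option is sketched below; seeding Python's `random.Random(42).sample` over sorted patient IDs is an assumption, not the authors' documented procedure, so it will not necessarily reproduce their cohort. The patient IDs are hypothetical placeholders.

```python
import random

def sample_patients(patient_ids, n=50, seed=42):
    """Draw a reproducible subset: same IDs and same seed give the same patients."""
    ids = sorted(patient_ids)  # sort first so directory-listing order does not matter
    return random.Random(seed).sample(ids, n)

# Hypothetical IDs; real UCSF-PDGM identifiers are not listed in this card.
all_ids = [f"PATIENT-{i:04d}" for i in range(500)]
subset = sample_patients(all_ids)
print(len(subset), subset == sample_patients(list(reversed(all_ids))))  # 50 True
```

Sorting before sampling makes the subset depend only on the set of IDs and the seed, not on how the files happen to be listed on disk.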
