# RCMTUNetV4-VLM — Brain Tumor Segmentation & Report Generation
## Description

Multimodal pipeline for brain tumor segmentation and automated neuro-oncology report generation. The architecture chains three stages: RCMTUNetV4 segmentation, RAG over 40 WHO CNS 2021 chunks (FAISS), and LLaVA-Med report generation.
## 📂 Files in this repository

| File | Description |
|---|---|
| `rcmt_unet_v4_final.pth` | RCMTUNetV4 segmentation weights |
| `pipeline.py` | Full architecture (all classes) |
| `rag_who_chunks.json` | 40 WHO CNS 2021 RAG chunks |
| `rag_faiss.index` | Precomputed FAISS index |
| `rag_embeddings.npy` | NumPy embeddings (backup) |
| `prompts.json` | All v3.2 prompts + ACP profiles |
| `config.json` | Configuration and metrics |
| `evaluation_metrics.json` | Detailed evaluation results |
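To fetch a single artifact instead of the whole snapshot, `hf_hub_download` works file by file; a minimal sketch:

```python
from huggingface_hub import hf_hub_download

# Download only the segmentation weights; returns the local cache path.
weights_path = hf_hub_download(
    repo_id="mayoula/rcmt-unet-v4-vlm",
    filename="rcmt_unet_v4_final.pth",
)
print(weights_path)
```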
## Performance (n=50, seed=42, UCSF-PDGM)

| Metric | Score | vs. Baseline |
|---|---|---|
| BERTScore-F | 0.814 | > MediVLM (0.616) ✅ |
| TBFact | 0.922 | > BTReport (0.353) ✅ |
| RadGraph-F1 | 0.871 | > AutoRG (0.380) ✅ |
| Anti-hallucination | 1.000 | unique to this pipeline ✅ |
| Cross-validation | 1.000 | unique to this pipeline ✅ |
| Global score | 0.852 | Class A ✅ |
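The detailed results behind this table ship in `evaluation_metrics.json`. Its exact schema is not documented on this card, so the sketch below just loads the file and lists its top-level keys:

```python
import json
from huggingface_hub import hf_hub_download

# Pull the detailed evaluation results and inspect their structure.
metrics_path = hf_hub_download(
    repo_id="mayoula/rcmt-unet-v4-vlm",
    filename="evaluation_metrics.json",
)
with open(metrics_path) as f:
    metrics = json.load(f)
print(list(metrics) if isinstance(metrics, dict) else type(metrics))
```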
## 🚀 Usage: loading the full pipeline
```python
import torch, json, faiss, numpy as np
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer

# 1. Download all repository files
local_dir = snapshot_download(repo_id="mayoula/rcmt-unet-v4-vlm")

# 2. Load the architecture
import sys; sys.path.insert(0, local_dir)
from pipeline import RCMTUNetV4

# 3. Load the segmentation weights
seg_model = RCMTUNetV4(in_channels=4, out_channels=4, features=(24,48,96,192))
seg_model.load_state_dict(torch.load(f"{local_dir}/rcmt_unet_v4_final.pth", map_location="cpu"))
seg_model.eval()
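
# (Assumption) The expected input geometry is not documented on this card.
# With in_channels=4 (BraTS-style MRI modalities), a dummy forward pass would be:
#   x = torch.randn(1, 4, 96, 96, 96)   # or (1, 4, H, W) if the network is 2D
#   with torch.no_grad():
#       mask = seg_model(x).argmax(1)   # out_channels=4: background + 3 tumor classes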

# 4. Load the RAG components
with open(f"{local_dir}/rag_who_chunks.json") as f:
    rag_data = json.load(f)
WHO_CHUNKS = rag_data["chunks"]
faiss_idx = faiss.read_index(f"{local_dir}/rag_faiss.index")
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
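# NOTE: query embeddings must come from the same model that built rag_faiss.index;
# this card's loading code uses all-MiniLM-L6-v2 throughout.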

# 5. Load the prompts
with open(f"{local_dir}/prompts.json") as f:
    prompts = json.load(f)
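# prompts.json holds the v3.2 templates and ACP profiles; inspect the available
# keys with list(prompts) (assuming the top level is a dict).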

# 6. Load the VLM (LLaVA-Med, not fine-tuned, reloaded from the Hub)
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
                         bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)
vlm_processor = LlavaNextProcessor.from_pretrained("microsoft/llava-med-v1.5-mistral-7b")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(
"microsoft/llava-med-v1.5-mistral-7b", quantization_config=bnb, device_map="auto")

# 7. RAG retrieval helper
def rag_retrieve(query, top_k=4):
    emb = embedder.encode([query], normalize_embeddings=True)
    D, I = faiss_idx.search(emb.astype(np.float32), top_k)
    refs = [f"[REF-{r+1}] {WHO_CHUNKS[i]}" for r, (d, i) in enumerate(zip(D[0], I[0])) if d > 0.10]
    return "\n".join(refs) if refs else "Standard glioma protocol (WHO CNS 2021)."
```
## Dataset

Trained and evaluated on UCSF-PDGM (n=50 patients, seed=42).