| """ | |
| server.py β BioStack FastAPI Backend | |
| Fully aligned to inference.py: | |
| β Same CoAtNetEncoder : self.encoder + global_pool="avg" (NOT backbone) | |
| β Same VisionT5ForGRPO : class name, generate() method | |
| β Same image preprocess: raw /255.0, no ImageNet normalize | |
| β Same generation : do_sample=True, temperature=0.9, top_p=0.999 | |
| β Same checkpoint logic: GRPO checkpoint-* > SFT best_model.pt | |
| β Same reward : contrastive + rouge_l + negation_safety + hf_judge (0.25 each) | |
| β Same state-dict keys : img_encoder.encoder.* (no _remap needed) | |
| β HF Judge score exposed for both SFT and GRPO comparison | |
| """ | |
import io, os, glob, traceback, gc, time, re
from pathlib import Path
from threading import Lock

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from transformers import T5ForConditionalGeneration, T5Tokenizer
from huggingface_hub import hf_hub_download
from rouge_score import rouge_scorer as rouge_scorer_lib
# ─────────────────────────────────────────────────────────
# Config – mirrors inference.py exactly
# ─────────────────────────────────────────────────────────
HF_REPO   = "AE-Shree/Biostack-Xray-NeMoGym"
SFT_FILE  = "best_model.pt"
GRPO_SUB  = "checkpoint-300"
GRPO_BIN  = "pytorch_model.bin"
GRPO_SAFE = "model.safetensors"

MODEL_DIR = Path("models")
MODEL_DIR.mkdir(exist_ok=True)

IMAGE_SIZE = 224

HF_EMBED_MODEL  = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
HF_NLI_MODEL    = "cross-encoder/nli-roberta-base"
HF_JUDGE_DEVICE = "cpu"

# (negation phrase, finding it contradicts) – used by the negation-safety reward
NEGATION_PAIRS = [
    ("no pneumonia", "pneumonia"),
    ("no effusion", "effusion"),
    ("no consolidation", "consolidation"),
    ("no cardiomegaly", "cardiomegaly"),
    ("no opacity", "opacity"),
    ("no infiltrate", "infiltrate"),
    ("clear lungs", "opacity"),
    ("clear lungs", "consolidation"),
    ("normal", "abnormal"),
]

NLI_LABELS = [
    "accurate and complete radiology report",
    "inaccurate or incomplete radiology report",
]

DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")
# ─────────────────────────────────────────────────────────
# Image preprocessing – IDENTICAL to inference.py
# Raw /255.0, permute, no ImageNet normalization
# ─────────────────────────────────────────────────────────
def preprocess(file_bytes: bytes) -> torch.Tensor:
    """Decode uploaded bytes into a (1, 3, 224, 224) float tensor in [0, 1]."""
    img = Image.open(io.BytesIO(file_bytes)).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    arr = np.array(img, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # HWC -> 1CHW
    return tensor.to(DEVICE)
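
# Shape sanity sketch (illustrative only – "xray.png" is a placeholder path):
#   t = preprocess(open("xray.png", "rb").read())
#   t.shape -> torch.Size([1, 3, 224, 224]); dtype float32, values in [0.0, 1.0]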
# ─────────────────────────────────────────────────────────
# Model Architecture – IDENTICAL to inference.py
# CoAtNetEncoder uses self.encoder + global_pool="avg"
# Class name is VisionT5ForGRPO (not VisionT5)
# ─────────────────────────────────────────────────────────
class CoAtNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # num_classes=0 + global_pool="avg" -> (B, 768) pooled features
        self.encoder = timm.create_model(
            "coatnet_1_rw_224", pretrained=False, num_classes=0, global_pool="avg"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)


class VisionT5ForGRPO(nn.Module):
    def __init__(self):
        super().__init__()
        self.img_encoder = CoAtNetEncoder()
        self.t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
        self.proj = nn.Linear(768, self.t5.config.d_model)

    def _encode_image(self, pixel_values: torch.Tensor):
        # Project the pooled image feature to d_model and treat it as a
        # single-token encoder sequence for the T5 decoder to attend over.
        feats = self.proj(self.img_encoder(pixel_values)).unsqueeze(1)
        enc_out = self.t5.encoder(inputs_embeds=feats)
        enc_attn = torch.ones(feats.shape[:2], dtype=torch.long, device=feats.device)
        return enc_out, enc_attn

    def generate(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        enc_out, enc_attn = self._encode_image(pixel_values)
        return self.t5.generate(
            encoder_outputs=enc_out,
            attention_mask=enc_attn,
            **kwargs,
        )
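
# Usage sketch (hypothetical random input, not part of the server flow – the
# real weights are loaded below via _build):
#   m = VisionT5ForGRPO().eval()
#   ids = m.generate(pixel_values=torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE),
#                    max_new_tokens=16)
#   # ids: (1, <=17) token-id tensor; decode with the T5 tokenizer loaded below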
# ─────────────────────────────────────────────────────────
# Checkpoint helpers – aligned to inference.py logic
# No _remap() needed: keys are img_encoder.encoder.* as saved
# ─────────────────────────────────────────────────────────
def _load_sd(path: str) -> dict:
    if path.endswith(".safetensors"):
        from safetensors.torch import load_file
        sd = load_file(path, device="cpu")
    else:
        sd = torch.load(path, map_location="cpu", weights_only=False)
    # Unwrap common trainer containers down to the raw state dict.
    for wrap in ("state_dict", "model_state_dict", "model"):
        if wrap in sd and isinstance(sd[wrap], dict):
            sd = sd[wrap]
            break
    return sd


def _ensure(filename, subfolder=None) -> str:
    """Return a cached local path if present, else download from the HF Hub."""
    fname = f"{subfolder}_{filename}" if subfolder else filename
    local = MODEL_DIR / fname
    if local.exists():
        return str(local)
    kw = dict(repo_id=HF_REPO, filename=filename, local_dir=str(MODEL_DIR))
    if subfolder:
        kw["subfolder"] = subfolder
    return hf_hub_download(**kw)
def _build(path: str) -> VisionT5ForGRPO:
    m = VisionT5ForGRPO()
    m.load_state_dict(_load_sd(path), strict=True)  # strict=True – same as inference.py
    m.eval()
    for p in m.parameters():
        p.requires_grad_(False)
    return m
def _find_grpo_checkpoint() -> str | None:
    """
    Mirror of inference.py find_best_checkpoint():
    scan MODEL_DIR for checkpoint-* folders, prefer latest.
    """
    ckpt_dirs = glob.glob(str(MODEL_DIR / "checkpoint-*"))
    if ckpt_dirs:
        ckpt_dirs = sorted(ckpt_dirs, key=lambda x: int(x.split("checkpoint-")[-1]))
        latest = ckpt_dirs[-1]
        for fname in (GRPO_BIN, GRPO_SAFE):
            p = os.path.join(latest, fname)
            if os.path.exists(p):
                return p
    return None
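
# Worked example of the numeric sort key (hypothetical directory contents):
#   ["models/checkpoint-100", "models/checkpoint-90", "models/checkpoint-300"]
#   -> sorted by the integer suffix, `latest` is checkpoint-300, and its
#      pytorch_model.bin (or model.safetensors) is returned if present.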
# ─────────────────────────────────────────────────────────
# Load models at startup
# ─────────────────────────────────────────────────────────
print("\n" + "="*60)
print("LOADING TOKENIZER")
tokenizer = T5Tokenizer.from_pretrained(
    "t5-small", truncation_side="left", padding_side="left", legacy=True)
print("OK")

print("LOADING SFT")
t0 = time.time()
sft_model = _build(_ensure(SFT_FILE))
print(f"SFT OK ({time.time()-t0:.1f}s)")

print("LOADING GRPO")
t0 = time.time()
try:
    grpo_path = _ensure(GRPO_BIN, GRPO_SUB)
except Exception:
    grpo_path = _ensure(GRPO_SAFE, GRPO_SUB)
grpo_model = _build(grpo_path)
print(f"GRPO OK ({time.time()-t0:.1f}s)")

gc.collect()
print("="*60 + "\nVisionT5ForGRPO models ready – reward judges load on first use\n" + "="*60)
# ─────────────────────────────────────────────────────────
# Report generation – IDENTICAL sampling params to inference.py
# do_sample=True, temperature=0.9, top_p=0.999, repetition_penalty=1.2
# ─────────────────────────────────────────────────────────
def _remove_numbers(text: str) -> str:
    """Strip digits and decimal numbers from the generated text."""
    return re.sub(r"\d+\.?\d*", "", text).strip()


def generate_report(model: VisionT5ForGRPO, pixel_values: torch.Tensor) -> str:
    with torch.no_grad():
        out = model.generate(
            pixel_values=pixel_values,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.9,
            top_p=0.999,
            repetition_penalty=1.2,
        )
    raw_text = tokenizer.decode(out[0], skip_special_tokens=True).strip()
    return _remove_numbers(raw_text)
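
# Because do_sample=True, repeated calls on the same image can return different
# reports. A deterministic variant (a sketch – NOT what inference.py uses):
#   out = model.generate(pixel_values=pixel_values, max_new_tokens=200,
#                        do_sample=False, num_beams=4, repetition_penalty=1.2)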
# ─────────────────────────────────────────────────────────
# Lazy reward judge singletons
# ─────────────────────────────────────────────────────────
_rouge_scorer_inst = None
_embed_tok_inst = None
_embed_model_inst = None
_nli_pipe_inst = None
_judge_lock = Lock()


def _get_judges():
    global _rouge_scorer_inst, _embed_tok_inst, _embed_model_inst, _nli_pipe_inst
    with _judge_lock:
        if _rouge_scorer_inst is None:
            print("Loading reward judges...")
            from transformers import AutoModel, AutoTokenizer, pipeline
            _rouge_scorer_inst = rouge_scorer_lib.RougeScorer(
                ["rougeL"], use_stemmer=True)
            _embed_tok_inst = AutoTokenizer.from_pretrained(HF_EMBED_MODEL)
            _embed_model_inst = AutoModel.from_pretrained(
                HF_EMBED_MODEL).cpu().eval()
            _nli_pipe_inst = pipeline(
                "zero-shot-classification",
                model=HF_NLI_MODEL,
                device=-1,
                tokenizer=AutoTokenizer.from_pretrained(
                    HF_NLI_MODEL, use_fast=False),
            )
            print("Reward judges ready")
    return _rouge_scorer_inst, _embed_tok_inst, _embed_model_inst, _nli_pipe_inst
# ─────────────────────────────────────────────────────────
# Reward components – IDENTICAL to inference.py XRayRewardEvaluator
# ─────────────────────────────────────────────────────────
def _rouge_l(pred: str, ref: str, scorer) -> float:
    if not pred.strip() or not ref.strip():
        return 0.0
    return scorer.score(ref, pred)["rougeL"].fmeasure


def _negation_safety(pred: str, ref: str) -> float:
    # Penalize contradictions in either direction: pred negates a finding the
    # reference asserts, or pred asserts a finding the reference negates.
    pl, rl = pred.lower(), ref.lower()
    penalty = 0.0
    for neg, pos in NEGATION_PAIRS:
        if neg in pl and pos in rl and neg not in rl:
            penalty += 0.25
        if neg in rl and pos in pl and neg not in pl:
            penalty += 0.25
    return max(0.0, 1.0 - penalty)
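
# Worked example:
#   _negation_safety("no pneumonia, clear lungs", "pneumonia with opacity")
#   ("no pneumonia", "pneumonia"): neg in pred, pos in ref, neg not in ref -> +0.25
#   ("clear lungs", "opacity"):    neg in pred, pos in ref, neg not in ref -> +0.25
#   penalty = 0.5 -> score = max(0.0, 1.0 - 0.5) = 0.5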
def _mean_pool(tok_emb, attn_mask):
    mask = attn_mask.unsqueeze(-1).expand(tok_emb.size()).float()
    return torch.sum(tok_emb * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


def _biomed_sim(pred: str, ref: str, embed_tok, embed_model) -> float:
    enc = embed_tok(
        [pred, ref], padding=True, truncation=True,
        max_length=512, return_tensors="pt"
    ).to(HF_JUDGE_DEVICE)
    with torch.no_grad():
        out = embed_model(**enc)
    emb = F.normalize(
        _mean_pool(out.last_hidden_state, enc["attention_mask"]), p=2, dim=1)
    # Cosine similarity of the two unit embeddings, mapped from [-1, 1] to [0, 1].
    return ((emb[0] * emb[1]).sum().item() + 1.0) / 2.0
def _nli_quality(pred: str, nli_pipe) -> float:
    text = pred[:512] if pred.strip() else "no findings"
    result = nli_pipe(text, candidate_labels=NLI_LABELS, multi_label=False)
    lsm = dict(zip(result["labels"], result["scores"]))
    return float(lsm.get(NLI_LABELS[0], 0.5))


def _baseline_rouge(ref: str, corpus_reports: list[str], scorer) -> float:
    """BM25 baseline – only used if a corpus is passed (optional for the server)."""
    if not corpus_reports:
        return 0.0
    from rank_bm25 import BM25Okapi
    bm25 = BM25Okapi([r.lower().split() for r in corpus_reports])
    scores = bm25.get_scores(ref.lower().split())
    top_idx = int(scores.argmax())
    return _rouge_l(corpus_reports[top_idx], ref, scorer)
def full_reward_breakdown(pred: str, ref: str,
                          corpus: list[str] | None = None) -> dict:
    """
    IDENTICAL to inference.py XRayRewardEvaluator.breakdown().
    contrastive uses the BM25 baseline if a corpus is provided, else uses rouge_l directly.
    """
    pred = pred.strip() or "no findings noted"
    ref = ref.strip() or "no findings noted"
    scorer, embed_tok, embed_model, nli_pipe = _get_judges()
    rl = _rouge_l(pred, ref, scorer)
    neg_safety = _negation_safety(pred, ref)
    if corpus:
        baseline = _baseline_rouge(ref, corpus, scorer)
        delta = rl - baseline
        contrastive = (max(-1.0, min(1.0, delta)) + 1.0) / 2.0
    else:
        contrastive = (max(-1.0, min(1.0, rl)) + 1.0) / 2.0
    biomed_sim = _biomed_sim(pred, ref, embed_tok, embed_model)
    nli_q = _nli_quality(pred, nli_pipe)
    hf_judge = (biomed_sim + nli_q) / 2.0
    total = 0.25 * contrastive + 0.25 * rl + 0.25 * neg_safety + 0.25 * hf_judge
    return {
        "contrastive": round(contrastive, 4),
        "rouge_l": round(rl, 4),
        "negation_safety": round(neg_safety, 4),
        "hf_judge": round(hf_judge, 4),
        "biomed_sim": round(biomed_sim, 4),
        "nli_quality": round(nli_q, 4),
        "total": round(total, 4),
    }
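
# Worked example with made-up component values (no corpus):
#   rl = 0.40 -> contrastive = (min(1, max(-1, 0.40)) + 1) / 2 = 0.70
#   neg_safety = 1.00, biomed_sim = 0.80, nli_q = 0.60 -> hf_judge = 0.70
#   total = 0.25 * (0.70 + 0.40 + 1.00 + 0.70) = 0.70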
def quick_reward(report: str) -> tuple[float, str]:
    """Fast proxy score – no reference or judge models needed."""
    KEY = ["lung", "heart", "normal", "clear", "opacity", "infiltrate",
           "cardiomegaly", "pleural", "pulmonary", "chest", "thorax",
           "pneumonia", "edema", "effusion", "consolidation"]
    rl = report.lower()
    present = [t for t in KEY if t in rl]
    words = len(report.split())
    term_s = len(present) / len(KEY)
    comp_s = min(1.0, words / 100.0)
    struct_s = 1.0 if 50 <= words <= 150 else 0.5
    score = max(0.0, min(1.0, term_s * 0.4 + comp_s * 0.3 + struct_s * 0.3))
    fb = (f"Reward Score: {score:.2f} | Medical Terminology: {term_s:.1%} | "
          f"Clinical Completeness: {comp_s:.1%} | "
          f"Report Structure: {struct_s:.1%}")
    return score, fb
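
# Worked example: a hypothetical 80-word report containing 6 of the 15 key terms:
#   term_s = 6 / 15 = 0.40, comp_s = min(1.0, 80 / 100) = 0.80, struct_s = 1.0
#   score = 0.40*0.4 + 0.80*0.3 + 1.0*0.3 = 0.16 + 0.24 + 0.30 = 0.70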
# ─────────────────────────────────────────────────────────
# FastAPI
# ─────────────────────────────────────────────────────────
app = FastAPI(title="BioStack API")
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_methods=["*"], allow_headers=["*"])


# NOTE: the route paths below are assumptions – adjust them to match whatever
# paths the frontend actually calls.
@app.get("/health")
def health():
    return {"status": "ok", "device": str(DEVICE)}


@app.post("/api/sft")
async def sft_endpoint(file: UploadFile = File(...)):
    try:
        tensor = preprocess(await file.read())
        report = generate_report(sft_model, tensor)
        return {"report": report}
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {e}"}
@app.post("/api/reward")  # path assumed, as above
async def reward_endpoint(
    file: UploadFile = File(...),
    ground_truth: str = Form(default=""),
):
    try:
        tensor = preprocess(await file.read())
        report = generate_report(sft_model, tensor)
        if ground_truth.strip():
            bd = full_reward_breakdown(report, ground_truth)
            return {
                "score": bd["total"],
                "feedback": (
                    f"Reward Score: {bd['total']:.4f} | "
                    f"Contrastive: {bd['contrastive']:.4f} | "
                    f"ROUGE-L: {bd['rouge_l']:.4f} | "
                    f"Negation Safety: {bd['negation_safety']:.4f} | "
                    f"HF Judge: {bd['hf_judge']:.4f}"
                ),
                "sft_report": report,
                # ── NEW: expose hf_judge directly for frontend comparison ──
                "hf_judge": bd["hf_judge"],
                "breakdown": bd,
                "has_breakdown": True,
            }
        else:
            score, feedback = quick_reward(report)
            return {
                "score": score,
                "feedback": feedback,
                "sft_report": report,
                # No ground truth – hf_judge not available
                "hf_judge": None,
                "breakdown": None,
                "has_breakdown": False,
            }
    except Exception as e:
        traceback.print_exc()
        return {"score": 0.0, "feedback": f"ERROR: {e}",
                "sft_report": "", "hf_judge": None,
                "breakdown": None, "has_breakdown": False}
@app.post("/api/grpo")  # path assumed, as above
async def grpo_reward_endpoint(
    file: UploadFile = File(...),
    ground_truth: str = Form(default=""),
):
    try:
        tensor = preprocess(await file.read())
        report = generate_report(grpo_model, tensor)
        if ground_truth.strip():
            bd = full_reward_breakdown(report, ground_truth)
            return {
                "report": report,
                # ── NEW: expose hf_judge directly for frontend comparison ──
                "hf_judge": bd["hf_judge"],
                "breakdown": bd,
                "has_breakdown": True,
            }
        return {
            "report": report,
            # No ground truth – hf_judge not available
            "hf_judge": None,
            "breakdown": None,
            "has_breakdown": False,
        }
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {e}", "hf_judge": None,
                "breakdown": None, "has_breakdown": False}
# Serve React build AFTER all API routes
if os.path.exists("build"):
    app.mount("/", StaticFiles(directory="build", html=True), name="static")
    print("React build mounted at /")
else:
    print("WARNING: ./build not found")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)