""" server.py — BioStack FastAPI Backend Fully aligned to inference.py: ✅ Same CoAtNetEncoder : self.encoder + global_pool="avg" (NOT backbone) ✅ Same VisionT5ForGRPO : class name, generate() method ✅ Same image preprocess: raw /255.0, no ImageNet normalize ✅ Same generation : do_sample=True, temperature=0.9, top_p=0.999 ✅ Same checkpoint logic: GRPO checkpoint-* > SFT best_model.pt ✅ Same reward : contrastive + rouge_l + negation_safety + hf_judge (0.25 each) ✅ Same state-dict keys : img_encoder.encoder.* (no _remap needed) ✅ HF Judge score exposed for both SFT and GRPO comparison """ import io, os, glob, traceback, gc, time from pathlib import Path from threading import Lock import torch import torch.nn as nn import torch.nn.functional as F import timm import numpy as np from PIL import Image from fastapi import FastAPI, File, UploadFile, Form from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from transformers import T5ForConditionalGeneration, T5Tokenizer from huggingface_hub import hf_hub_download from rouge_score import rouge_scorer as rouge_scorer_lib # ───────────────────────────────────────────────────────── # Config — mirrors inference.py exactly # ───────────────────────────────────────────────────────── HF_REPO = "AE-Shree/Biostack-Xray-NeMoGym" SFT_FILE = "best_model.pt" GRPO_SUB = "checkpoint-300" GRPO_BIN = "pytorch_model.bin" GRPO_SAFE = "model.safetensors" MODEL_DIR = Path("models") MODEL_DIR.mkdir(exist_ok=True) IMAGE_SIZE = 224 HF_EMBED_MODEL = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" HF_NLI_MODEL = "cross-encoder/nli-roberta-base" HF_JUDGE_DEVICE = "cpu" NEGATION_PAIRS = [ ("no pneumonia", "pneumonia"), ("no effusion", "effusion"), ("no consolidation", "consolidation"), ("no cardiomegaly", "cardiomegaly"), ("no opacity", "opacity"), ("no infiltrate", "infiltrate"), ("clear lungs", "opacity"), ("clear lungs", "consolidation"), ("normal", "abnormal"), ] NLI_LABELS = [ "accurate and complete radiology report", "inaccurate or incomplete radiology report", ] DEVICE = torch.device("cpu") print(f"Device: {DEVICE}") # ───────────────────────────────────────────────────────── # Image preprocessing — IDENTICAL to inference.py # Raw /255.0, permute, no ImageNet normalization # ───────────────────────────────────────────────────────── def preprocess(file_bytes: bytes) -> torch.Tensor: img = Image.open(io.BytesIO(file_bytes)).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE)) arr = np.array(img, dtype=np.float32) / 255.0 tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) return tensor.to(DEVICE) # ───────────────────────────────────────────────────────── # Model Architecture — IDENTICAL to inference.py # CoAtNetEncoder uses self.encoder + global_pool="avg" # Class name is VisionT5ForGRPO (not VisionT5) # ───────────────────────────────────────────────────────── class CoAtNetEncoder(nn.Module): def __init__(self): super().__init__() self.encoder = timm.create_model( "coatnet_1_rw_224", pretrained=False, num_classes=0, global_pool="avg" ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.encoder(x) class VisionT5ForGRPO(nn.Module): def __init__(self): super().__init__() self.img_encoder = CoAtNetEncoder() self.t5 = T5ForConditionalGeneration.from_pretrained("t5-small") self.proj = nn.Linear(768, self.t5.config.d_model) def _encode_image(self, pixel_values: torch.Tensor): feats = self.proj(self.img_encoder(pixel_values)).unsqueeze(1) enc_out = self.t5.encoder(inputs_embeds=feats) enc_attn = torch.ones(feats.shape[:2], dtype=torch.long, device=feats.device) return enc_out, enc_attn def generate(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor: enc_out, enc_attn = self._encode_image(pixel_values) return self.t5.generate( encoder_outputs=enc_out, attention_mask=enc_attn, **kwargs, ) # ───────────────────────────────────────────────────────── # Checkpoint helpers — aligned to inference.py logic # No _remap() needed: keys are img_encoder.encoder.* as saved # ───────────────────────────────────────────────────────── def _load_sd(path: str) -> dict: if path.endswith(".safetensors"): from safetensors.torch import load_file sd = load_file(path, device="cpu") else: sd = torch.load(path, map_location="cpu", weights_only=False) for wrap in ("state_dict", "model_state_dict", "model"): if wrap in sd and isinstance(sd[wrap], dict): sd = sd[wrap] break return sd def _ensure(filename, subfolder=None) -> str: fname = f"{subfolder}_{filename}" if subfolder else filename local = MODEL_DIR / fname if local.exists(): return str(local) kw = dict(repo_id=HF_REPO, filename=filename, local_dir=str(MODEL_DIR)) if subfolder: kw["subfolder"] = subfolder return hf_hub_download(**kw) def _build(path: str) -> VisionT5ForGRPO: m = VisionT5ForGRPO() m.load_state_dict(_load_sd(path), strict=True) # strict=True — same as inference.py m.eval() for p in m.parameters(): p.requires_grad_(False) return m def _find_grpo_checkpoint() -> str | None: """ Mirror of inference.py find_best_checkpoint(): scan MODEL_DIR for checkpoint-* folders, prefer latest. """ ckpt_dirs = glob.glob(str(MODEL_DIR / "checkpoint-*")) if ckpt_dirs: ckpt_dirs = sorted(ckpt_dirs, key=lambda x: int(x.split("checkpoint-")[-1])) latest = ckpt_dirs[-1] for fname in (GRPO_BIN, GRPO_SAFE): p = os.path.join(latest, fname) if os.path.exists(p): return p return None # ───────────────────────────────────────────────────────── # Load models at startup # ───────────────────────────────────────────────────────── print("\n" + "="*60) print("LOADING TOKENIZER") tokenizer = T5Tokenizer.from_pretrained( "t5-small", truncation_side="left", padding_side="left", legacy=True) print("OK") print("LOADING SFT") t0 = time.time() sft_model = _build(_ensure(SFT_FILE)) print(f"SFT OK ({time.time()-t0:.1f}s)") print("LOADING GRPO") t0 = time.time() try: grpo_path = _ensure(GRPO_BIN, GRPO_SUB) except Exception: grpo_path = _ensure(GRPO_SAFE, GRPO_SUB) grpo_model = _build(grpo_path) print(f"GRPO OK ({time.time()-t0:.1f}s)") gc.collect() print("="*60 + "\nVisionT5ForGRPO models ready — reward judges load on first use\n" + "="*60) # ───────────────────────────────────────────────────────── # Report generation — IDENTICAL sampling params to inference.py # do_sample=True, temperature=0.9, top_p=0.999, repetition_penalty=1.2 # ───────────────────────────────────────────────────────── def _remove_numbers(text: str) -> str: """Remove all numbers from the generated text.""" import re # Remove digits and decimal numbers return re.sub(r'\d+\.?\d*', '', text).strip() def generate_report(model: VisionT5ForGRPO, pixel_values: torch.Tensor) -> str: with torch.no_grad(): out = model.generate( pixel_values=pixel_values, max_new_tokens=200, do_sample=True, temperature=0.9, top_p=0.999, repetition_penalty=1.2, ) raw_text = tokenizer.decode(out[0], skip_special_tokens=True).strip() return _remove_numbers(raw_text) # ───────────────────────────────────────────────────────── # Lazy reward judge singletons # ───────────────────────────────────────────────────────── _rouge_scorer_inst = None _embed_tok_inst = None _embed_model_inst = None _nli_pipe_inst = None _judge_lock = Lock() def _get_judges(): global _rouge_scorer_inst, _embed_tok_inst, _embed_model_inst, _nli_pipe_inst with _judge_lock: if _rouge_scorer_inst is None: print("Loading reward judges...") from transformers import AutoModel, AutoTokenizer, pipeline _rouge_scorer_inst = rouge_scorer_lib.RougeScorer( ["rougeL"], use_stemmer=True) _embed_tok_inst = AutoTokenizer.from_pretrained(HF_EMBED_MODEL) _embed_model_inst = AutoModel.from_pretrained( HF_EMBED_MODEL).cpu().eval() _nli_pipe_inst = pipeline( "zero-shot-classification", model=HF_NLI_MODEL, device=-1, tokenizer=AutoTokenizer.from_pretrained( HF_NLI_MODEL, use_fast=False), ) print("Reward judges ready") return _rouge_scorer_inst, _embed_tok_inst, _embed_model_inst, _nli_pipe_inst # ───────────────────────────────────────────────────────── # Reward components — IDENTICAL to inference.py XRayRewardEvaluator # ───────────────────────────────────────────────────────── def _rouge_l(pred: str, ref: str, scorer) -> float: if not pred.strip() or not ref.strip(): return 0.0 return scorer.score(ref, pred)["rougeL"].fmeasure def _negation_safety(pred: str, ref: str) -> float: pl, rl = pred.lower(), ref.lower() penalty = 0.0 for neg, pos in NEGATION_PAIRS: if neg in pl and pos in rl and neg not in rl: penalty += 0.25 if neg in rl and pos in pl and neg not in pl: penalty += 0.25 return max(0.0, 1.0 - penalty) def _mean_pool(tok_emb, attn_mask): mask = attn_mask.unsqueeze(-1).expand(tok_emb.size()).float() return torch.sum(tok_emb * mask, 1) / torch.clamp(mask.sum(1), min=1e-9) def _biomed_sim(pred: str, ref: str, embed_tok, embed_model) -> float: enc = embed_tok( [pred, ref], padding=True, truncation=True, max_length=512, return_tensors="pt" ).to(HF_JUDGE_DEVICE) with torch.no_grad(): out = embed_model(**enc) emb = F.normalize( _mean_pool(out.last_hidden_state, enc["attention_mask"]), p=2, dim=1) return ((emb[0] * emb[1]).sum().item() + 1.0) / 2.0 def _nli_quality(pred: str, nli_pipe) -> float: text = pred[:512] if pred.strip() else "no findings" result = nli_pipe(text, candidate_labels=NLI_LABELS, multi_label=False) lsm = dict(zip(result["labels"], result["scores"])) return float(lsm.get(NLI_LABELS[0], 0.5)) def _baseline_rouge(ref: str, corpus_reports: list[str], scorer) -> float: """BM25 baseline — only used if corpus is passed (optional for server).""" if not corpus_reports: return 0.0 from rank_bm25 import BM25Okapi bm25 = BM25Okapi([r.lower().split() for r in corpus_reports]) scores = bm25.get_scores(ref.lower().split()) top_idx = int(scores.argmax()) return _rouge_l(corpus_reports[top_idx], ref, scorer) def full_reward_breakdown(pred: str, ref: str, corpus: list[str] | None = None) -> dict: """ IDENTICAL to inference.py XRayRewardEvaluator.breakdown(). contrastive uses BM25 baseline if corpus provided, else uses rouge_l directly. """ pred = pred.strip() or "no findings noted" ref = ref.strip() or "no findings noted" scorer, embed_tok, embed_model, nli_pipe = _get_judges() rl = _rouge_l(pred, ref, scorer) neg_safety = _negation_safety(pred, ref) if corpus: baseline = _baseline_rouge(ref, corpus, scorer) delta = rl - baseline contrastive = (max(-1.0, min(1.0, delta)) + 1.0) / 2.0 else: contrastive = (max(-1.0, min(1.0, rl)) + 1.0) / 2.0 biomed_sim = _biomed_sim(pred, ref, embed_tok, embed_model) nli_q = _nli_quality(pred, nli_pipe) hf_judge = (biomed_sim + nli_q) / 2.0 total = 0.25 * contrastive + 0.25 * rl + 0.25 * neg_safety + 0.25 * hf_judge return { "contrastive": round(contrastive, 4), "rouge_l": round(rl, 4), "negation_safety": round(neg_safety, 4), "hf_judge": round(hf_judge, 4), "biomed_sim": round(biomed_sim, 4), "nli_quality": round(nli_q, 4), "total": round(total, 4), } def quick_reward(report: str) -> tuple[float, str]: """Fast proxy score — no reference or judge models needed.""" KEY = ["lung", "heart", "normal", "clear", "opacity", "infiltrate", "cardiomegaly", "pleural", "pulmonary", "chest", "thorax", "pneumonia", "edema", "effusion", "consolidation"] rl = report.lower() present = [t for t in KEY if t in rl] words = len(report.split()) term_s = len(present) / len(KEY) comp_s = min(1.0, words / 100.0) struct_s = 1.0 if 50 <= words <= 150 else 0.5 score = max(0.0, min(1.0, term_s*0.4 + comp_s*0.3 + struct_s*0.3)) fb = (f"Reward Score: {score:.2f} | Medical Terminology: {term_s:.1%} | " f"Clinical Completeness: {comp_s:.1%} | " f"Report Structure: {struct_s:.1%}") return score, fb # ───────────────────────────────────────────────────────── # FastAPI # ───────────────────────────────────────────────────────── app = FastAPI(title="BioStack API") app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) @app.get("/health") def health(): return {"status": "ok", "device": str(DEVICE)} @app.post("/sft") async def sft_endpoint(file: UploadFile = File(...)): try: tensor = preprocess(await file.read()) report = generate_report(sft_model, tensor) return {"report": report} except Exception as e: traceback.print_exc() return {"report": f"ERROR: {e}"} @app.post("/reward") async def reward_endpoint( file: UploadFile = File(...), ground_truth: str = Form(default=""), ): try: tensor = preprocess(await file.read()) report = generate_report(sft_model, tensor) if ground_truth.strip(): bd = full_reward_breakdown(report, ground_truth) return { "score": bd["total"], "feedback": ( f"Reward Score: {bd['total']:.4f} | " f"Contrastive: {bd['contrastive']:.4f} | " f"ROUGE-L: {bd['rouge_l']:.4f} | " f"Negation Safety: {bd['negation_safety']:.4f} | " f"HF Judge: {bd['hf_judge']:.4f}" ), "sft_report": report, # ── NEW: expose hf_judge directly for frontend comparison ── "hf_judge": bd["hf_judge"], "breakdown": bd, "has_breakdown": True, } else: score, feedback = quick_reward(report) return { "score": score, "feedback": feedback, "sft_report": report, # No ground truth → hf_judge not available "hf_judge": None, "breakdown": None, "has_breakdown": False, } except Exception as e: traceback.print_exc() return {"score": 0.0, "feedback": f"ERROR: {e}", "sft_report": "", "hf_judge": None, "breakdown": None, "has_breakdown": False} @app.post("/grpo_reward") async def grpo_reward_endpoint( file: UploadFile = File(...), ground_truth: str = Form(default=""), ): try: tensor = preprocess(await file.read()) report = generate_report(grpo_model, tensor) if ground_truth.strip(): bd = full_reward_breakdown(report, ground_truth) return { "report": report, # ── NEW: expose hf_judge directly for frontend comparison ── "hf_judge": bd["hf_judge"], "breakdown": bd, "has_breakdown": True, } return { "report": report, # No ground truth → hf_judge not available "hf_judge": None, "breakdown": None, "has_breakdown": False, } except Exception as e: traceback.print_exc() return {"report": f"ERROR: {e}", "hf_judge": None, "breakdown": None, "has_breakdown": False} # Serve React build AFTER all API routes if os.path.exists("build"): app.mount("/", StaticFiles(directory="build", html=True), name="static") print("React build mounted at /") else: print("WARNING: ./build not found") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)