# Hugging Face file-page header (not part of the program):
# author AE-Shree, commit "Update server.py", a818d02 (verified)
"""
server.py — BioStack FastAPI Backend
Fully aligned to inference.py:
✅ Same CoAtNetEncoder : self.encoder + global_pool="avg" (NOT backbone)
✅ Same VisionT5ForGRPO : class name, generate() method
✅ Same image preprocess: raw /255.0, no ImageNet normalize
✅ Same generation : do_sample=True, temperature=0.9, top_p=0.999
✅ Same checkpoint logic: GRPO checkpoint-* > SFT best_model.pt
✅ Same reward : contrastive + rouge_l + negation_safety + hf_judge (0.25 each)
✅ Same state-dict keys : img_encoder.encoder.* (no _remap needed)
✅ HF Judge score exposed for both SFT and GRPO comparison
"""
import io, os, glob, traceback, gc, time
from pathlib import Path
from threading import Lock
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from transformers import T5ForConditionalGeneration, T5Tokenizer
from huggingface_hub import hf_hub_download
from rouge_score import rouge_scorer as rouge_scorer_lib
# ─────────────────────────────────────────────────────────
# Config β€” mirrors inference.py exactly
# ─────────────────────────────────────────────────────────
# Hugging Face repository and the checkpoint files pulled from it.
HF_REPO = "AE-Shree/Biostack-Xray-NeMoGym"
SFT_FILE = "best_model.pt"
GRPO_SUB = "checkpoint-300"
GRPO_BIN = "pytorch_model.bin"
GRPO_SAFE = "model.safetensors"

# Local directory for downloaded weights; created as an import-time side effect.
MODEL_DIR = Path("models")
MODEL_DIR.mkdir(exist_ok=True)

# Side length (pixels) of the square image fed to the encoder.
IMAGE_SIZE = 224

# Models backing the "hf_judge" reward component, and the device they run on.
HF_EMBED_MODEL = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
HF_NLI_MODEL = "cross-encoder/nli-roberta-base"
HF_JUDGE_DEVICE = "cpu"

# (negated phrase, affirmed term) pairs consumed by _negation_safety().
NEGATION_PAIRS = [
    ("no pneumonia", "pneumonia"),
    ("no effusion", "effusion"),
    ("no consolidation", "consolidation"),
    ("no cardiomegaly", "cardiomegaly"),
    ("no opacity", "opacity"),
    ("no infiltrate", "infiltrate"),
    ("clear lungs", "opacity"),
    ("clear lungs", "consolidation"),
    ("normal", "abnormal"),
]

# Zero-shot candidate labels; index 0 is the "good report" label that
# _nli_quality() reads the score for.
NLI_LABELS = [
    "accurate and complete radiology report",
    "inaccurate or incomplete radiology report",
]

# Generation models always run on CPU in this server.
DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")
# ─────────────────────────────────────────────────────────
# Image preprocessing β€” IDENTICAL to inference.py
# Raw /255.0, permute, no ImageNet normalization
# ─────────────────────────────────────────────────────────
def preprocess(file_bytes: bytes) -> torch.Tensor:
    """Decode raw image bytes into a model-ready tensor on ``DEVICE``.

    The image is converted to RGB, resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``,
    scaled to ``[0, 1]`` by dividing by 255 (no mean/std normalization), and
    rearranged from HWC to a ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)`` float32 batch.

    Args:
        file_bytes: Encoded image content (any format PIL can open).

    Returns:
        A single-image batch tensor placed on ``DEVICE``.
    """
    buffer = io.BytesIO(file_bytes)
    rgb_image = Image.open(buffer).convert("RGB")
    resized = rgb_image.resize((IMAGE_SIZE, IMAGE_SIZE))

    # HWC float32 in [0, 1]; deliberately no ImageNet normalization.
    pixels = np.array(resized, dtype=np.float32) / 255.0

    # HWC -> CHW, then add the leading batch dimension.
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
    batch = channels_first.unsqueeze(0)
    return batch.to(DEVICE)
# ─────────────────────────────────────────────────────────
# Model Architecture β€” IDENTICAL to inference.py
# CoAtNetEncoder uses self.encoder + global_pool="avg"
# Class name is VisionT5ForGRPO (not VisionT5)
# ─────────────────────────────────────────────────────────
class CoAtNetEncoder(nn.Module):
    """Image feature extractor wrapping timm's ``coatnet_1_rw_224``.

    The backbone is created without pretrained weights, with
    ``num_classes=0`` and ``global_pool="avg"``. It must stay registered
    under the attribute name ``encoder`` so checkpoint keys of the form
    ``img_encoder.encoder.*`` load without remapping.
    """

    def __init__(self):
        super().__init__()
        backbone_options = dict(pretrained=False, num_classes=0, global_pool="avg")
        self.encoder = timm.create_model("coatnet_1_rw_224", **backbone_options)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone on image batch ``x`` and return its output."""
        features = self.encoder(x)
        return features
class VisionT5ForGRPO(nn.Module):
    """CoAtNet image encoder feeding a ``t5-small`` text generator.

    The image features are linearly projected to T5's ``d_model`` and
    passed to the T5 encoder as a single embedded position
    (``unsqueeze(1)``); T5 then generates report tokens from that
    encoder output.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (img_encoder, t5, proj) are part of the
        # checkpoint key layout; do not rename them.
        self.img_encoder = CoAtNetEncoder()
        self.t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
        t5_width = self.t5.config.d_model
        self.proj = nn.Linear(768, t5_width)

    def _encode_image(self, pixel_values: torch.Tensor):
        """Return ``(encoder_outputs, attention_mask)`` for an image batch."""
        pooled = self.img_encoder(pixel_values)
        # Insert a sequence axis so the projected features form a
        # length-1 "sentence" for the T5 encoder.
        image_tokens = self.proj(pooled).unsqueeze(1)
        encoded = self.t5.encoder(inputs_embeds=image_tokens)
        batch_size, seq_len = image_tokens.shape[:2]
        mask = torch.ones(
            (batch_size, seq_len), dtype=torch.long, device=image_tokens.device
        )
        return encoded, mask

    def generate(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        """Generate token ids for ``pixel_values``.

        Extra keyword arguments are forwarded unchanged to
        ``T5ForConditionalGeneration.generate``.
        """
        encoded, mask = self._encode_image(pixel_values)
        return self.t5.generate(
            encoder_outputs=encoded,
            attention_mask=mask,
            **kwargs,
        )
# ─────────────────────────────────────────────────────────
# Checkpoint helpers β€” aligned to inference.py logic
# No _remap() needed: keys are img_encoder.encoder.* as saved
# ─────────────────────────────────────────────────────────
def _load_sd(path: str) -> dict:
    """Load a checkpoint file from ``path`` and return its state dict.

    ``.safetensors`` files are read with ``safetensors.torch.load_file``;
    anything else goes through ``torch.load``. If the loaded object nests
    its weights under ``"state_dict"``, ``"model_state_dict"`` or
    ``"model"`` (checked in that order), the first such nested dict is
    returned instead of the outer container.
    """
    if path.endswith(".safetensors"):
        from safetensors.torch import load_file

        loaded = load_file(path, device="cpu")
    else:
        # NOTE(review): weights_only=False unpickles arbitrary objects;
        # that is only safe for checkpoints from a trusted source.
        loaded = torch.load(path, map_location="cpu", weights_only=False)

    wrapper_keys = ("state_dict", "model_state_dict", "model")
    wrapper = next(
        (
            key
            for key in wrapper_keys
            if key in loaded and isinstance(loaded[key], dict)
        ),
        None,
    )
    return loaded if wrapper is None else loaded[wrapper]
def _ensure(filename, subfolder=None) -> str:
    """Return a local path for a checkpoint file, downloading it if needed.

    Args:
        filename: File name inside the Hugging Face repo ``HF_REPO``.
        subfolder: Optional repo subfolder containing ``filename``.

    Returns:
        Path of an existing local copy under ``MODEL_DIR`` when one is
        found, otherwise the path returned by ``hf_hub_download``.

    Raises:
        Whatever ``hf_hub_download`` raises when the file cannot be fetched.
    """
    if subfolder:
        candidates = (
            # Where hf_hub_download(local_dir=..., subfolder=...) places the
            # file: the repo layout is mirrored under local_dir.
            MODEL_DIR / subfolder / filename,
            # Legacy flat name checked by earlier versions of this helper,
            # kept so manually placed files are still honoured.
            MODEL_DIR / f"{subfolder}_{filename}",
        )
    else:
        candidates = (MODEL_DIR / filename,)

    for candidate in candidates:
        if candidate.exists():
            return str(candidate)

    kw = dict(repo_id=HF_REPO, filename=filename, local_dir=str(MODEL_DIR))
    if subfolder:
        kw["subfolder"] = subfolder
    return hf_hub_download(**kw)
def _build(path: str) -> VisionT5ForGRPO:
    """Instantiate ``VisionT5ForGRPO`` and load frozen weights from ``path``.

    The state dict is loaded with ``strict=True`` (same as inference.py),
    so any missing or unexpected key raises. The model is returned in
    eval mode with gradients disabled on every parameter.
    """
    model = VisionT5ForGRPO()
    weights = _load_sd(path)
    model.load_state_dict(weights, strict=True)
    model.eval()
    # Inference only: no parameter needs gradient tracking.
    for param in model.parameters():
        param.requires_grad_(False)
    return model
def _find_grpo_checkpoint() -> str | None:
    """Locate the weights file of the highest-numbered local GRPO checkpoint.

    Mirror of inference.py find_best_checkpoint(): scan ``MODEL_DIR`` for
    ``checkpoint-<step>`` entries and prefer the latest step. Only that
    latest checkpoint is inspected; inside it ``GRPO_BIN`` is preferred
    over ``GRPO_SAFE``.

    Entries whose suffix is not an integer (e.g. ``checkpoint-best`` or a
    flat file such as ``checkpoint-300_pytorch_model.bin``) are ignored
    instead of aborting the scan with ``ValueError``.

    Returns:
        Path to the weights file, or ``None`` when no numbered checkpoint
        exists or the latest one contains neither weights file.
    """
    numbered = []
    for entry in glob.glob(str(MODEL_DIR / "checkpoint-*")):
        try:
            step = int(entry.split("checkpoint-")[-1])
        except ValueError:
            # Not a "checkpoint-<int>" name; skip it.
            continue
        numbered.append((step, entry))

    if not numbered:
        return None

    # Stable sort on the step only, so ties resolve as a plain sorted() would.
    numbered.sort(key=lambda item: item[0])
    latest = numbered[-1][1]
    for fname in (GRPO_BIN, GRPO_SAFE):
        candidate = os.path.join(latest, fname)
        if os.path.exists(candidate):
            return candidate
    return None
# ─────────────────────────────────────────────────────────
# Load models at startup
# ─────────────────────────────────────────────────────────
# Import-time startup: load the tokenizer and both generation models once,
# so request handlers can use them as module-level globals.
print("\n" + "="*60)
print("LOADING TOKENIZER")
tokenizer = T5Tokenizer.from_pretrained(
    "t5-small", truncation_side="left", padding_side="left", legacy=True)
print("OK")

print("LOADING SFT")
t0 = time.time()
sft_model = _build(_ensure(SFT_FILE))
print(f"SFT OK ({time.time()-t0:.1f}s)")

print("LOADING GRPO")
t0 = time.time()
try:
    grpo_path = _ensure(GRPO_BIN, GRPO_SUB)
except Exception as exc:
    # Best-effort fallback: the checkpoint may only ship safetensors.
    # Report why the .bin attempt failed instead of hiding it, so a
    # network/auth problem is visible in the startup log.
    print(f"GRPO {GRPO_BIN} unavailable ({exc!r}); trying {GRPO_SAFE}")
    grpo_path = _ensure(GRPO_SAFE, GRPO_SUB)
grpo_model = _build(grpo_path)
print(f"GRPO OK ({time.time()-t0:.1f}s)")

# Release temporary state-dict copies left over from loading.
gc.collect()
print("="*60 + "\nVisionT5ForGRPO models ready — reward judges load on first use\n" + "="*60)
# ─────────────────────────────────────────────────────────
# Report generation β€” IDENTICAL sampling params to inference.py
# do_sample=True, temperature=0.9, top_p=0.999, repetition_penalty=1.2
# ─────────────────────────────────────────────────────────
def _remove_numbers(text: str) -> str:
    """Remove all numbers from the generated text.

    Integers and decimals (``12``, ``12.5``) are deleted. A period is only
    consumed when digits follow it, so a sentence-ending period after a
    number is preserved (``"rib 5. Lungs"`` keeps its ``.``). Runs of
    spaces/tabs left behind by a removal are collapsed to one space and
    the result is stripped.
    """
    import re

    # The fractional part is optional but must be ".<digits>", never a bare ".".
    without_numbers = re.sub(r"\d+(?:\.\d+)?", "", text)
    # Deleting a number between two spaces leaves a doubled gap; tidy it up.
    return re.sub(r"[ \t]{2,}", " ", without_numbers).strip()
def generate_report(model: VisionT5ForGRPO, pixel_values: torch.Tensor) -> str:
    """Sample one report for ``pixel_values`` and return it without numbers.

    Uses stochastic decoding (``do_sample=True``), so repeated calls on
    the same image can return different text. Only the first sequence of
    the generated batch is decoded.
    """
    # Sampling parameters must stay identical to inference.py.
    sampling = dict(
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
        top_p=0.999,
        repetition_penalty=1.2,
    )
    with torch.no_grad():
        token_ids = model.generate(pixel_values=pixel_values, **sampling)
    decoded = tokenizer.decode(token_ids[0], skip_special_tokens=True)
    return _remove_numbers(decoded.strip())
# ─────────────────────────────────────────────────────────
# Lazy reward judge singletons
# ─────────────────────────────────────────────────────────
# Lazily created reward judges; all four are populated together by
# _get_judges() and stay None until its first successful call.
_rouge_scorer_inst = None
_embed_tok_inst = None
_embed_model_inst = None
_nli_pipe_inst = None
# Serializes first-use initialization across concurrent callers.
_judge_lock = Lock()


def _get_judges():
    """Return ``(rouge_scorer, embed_tokenizer, embed_model, nli_pipeline)``.

    The judges are built on first use and cached in module globals.
    Construction happens into local variables and the globals are only
    assigned once every judge has loaded. A failure part-way through
    (e.g. a model download error) therefore leaves the cache untouched
    and the next call retries, instead of returning ``None`` judges
    forever after.

    Raises:
        Any exception raised while constructing the tokenizer, embedding
        model or NLI pipeline.
    """
    global _rouge_scorer_inst, _embed_tok_inst, _embed_model_inst, _nli_pipe_inst
    with _judge_lock:
        if _rouge_scorer_inst is None:
            print("Loading reward judges...")
            from transformers import AutoModel, AutoTokenizer, pipeline

            scorer = rouge_scorer_lib.RougeScorer(
                ["rougeL"], use_stemmer=True)
            embed_tok = AutoTokenizer.from_pretrained(HF_EMBED_MODEL)
            embed_model = AutoModel.from_pretrained(
                HF_EMBED_MODEL).cpu().eval()
            nli_pipe = pipeline(
                "zero-shot-classification",
                model=HF_NLI_MODEL,
                device=-1,
                tokenizer=AutoTokenizer.from_pretrained(
                    HF_NLI_MODEL, use_fast=False),
            )

            # Publish only after everything loaded; the ROUGE scorer is the
            # "initialized" sentinel, so it is assigned last.
            _embed_tok_inst = embed_tok
            _embed_model_inst = embed_model
            _nli_pipe_inst = nli_pipe
            _rouge_scorer_inst = scorer
            print("Reward judges ready")
        return (_rouge_scorer_inst, _embed_tok_inst,
                _embed_model_inst, _nli_pipe_inst)
# ─────────────────────────────────────────────────────────
# Reward components β€” IDENTICAL to inference.py XRayRewardEvaluator
# ─────────────────────────────────────────────────────────
def _rouge_l(pred: str, ref: str, scorer) -> float:
    """Return the ROUGE-L F-measure of ``pred`` against ``ref``.

    ``scorer.score`` is called as ``score(ref, pred)`` and its
    ``["rougeL"].fmeasure`` is returned. If either text is empty or
    whitespace-only the scorer is not called and the result is ``0.0``.
    """
    if pred.strip() and ref.strip():
        return scorer.score(ref, pred)["rougeL"].fmeasure
    return 0.0
def _negation_safety(pred: str, ref: str) -> float:
    """Score how well ``pred`` agrees with ``ref`` on negated findings.

    Starts at 1.0 and subtracts 0.25 for every ``NEGATION_PAIRS`` entry
    where one text contains the negated phrase while the other contains
    the affirmed term without the negated phrase (checked in both
    directions, case-insensitive substring match). Floored at 0.0.
    """
    pred_lc = pred.lower()
    ref_lc = ref.lower()
    penalty = 0.0
    for negated, affirmed in NEGATION_PAIRS:
        pred_negates = negated in pred_lc
        ref_negates = negated in ref_lc
        # Prediction denies a finding the reference affirms.
        if pred_negates and not ref_negates and affirmed in ref_lc:
            penalty += 0.25
        # Prediction affirms a finding the reference denies.
        if ref_negates and not pred_negates and affirmed in pred_lc:
            penalty += 0.25
    # NOTE(review): for a pair such as ("normal", "abnormal") the negated
    # phrase is a substring of the affirmed term, so neither branch can
    # fire for it. Left as-is to stay identical to the training reward.
    return max(0.0, 1.0 - penalty)
def _mean_pool(tok_emb, attn_mask):
    """Average token embeddings over dimension 1, ignoring masked positions.

    ``attn_mask`` is broadcast over the embedding dimension; the divisor
    is clamped to at least ``1e-9`` so an all-zero mask yields zeros
    rather than a division by zero.
    """
    weights = attn_mask.unsqueeze(-1).expand(tok_emb.size()).float()
    masked_sum = torch.sum(tok_emb * weights, 1)
    token_count = torch.clamp(weights.sum(1), min=1e-9)
    return masked_sum / token_count
def _biomed_sim(pred: str, ref: str, embed_tok, embed_model) -> float:
    """Return the embedding similarity of ``pred`` and ``ref`` in ``[0, 1]``.

    Both texts are tokenized together (truncated to 512 tokens),
    embedded with ``embed_model``, mean-pooled over unmasked tokens and
    L2-normalized. Their dot product (cosine similarity, in ``[-1, 1]``)
    is rescaled with ``(cos + 1) / 2``.
    """
    batch = embed_tok(
        [pred, ref],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(HF_JUDGE_DEVICE)

    with torch.no_grad():
        model_out = embed_model(**batch)

    pooled = _mean_pool(model_out.last_hidden_state, batch["attention_mask"])
    unit = F.normalize(pooled, p=2, dim=1)
    # Row 0 is the prediction, row 1 the reference.
    cosine = (unit[0] * unit[1]).sum().item()
    return (cosine + 1.0) / 2.0
def _nli_quality(pred: str, nli_pipe) -> float:
    """Return the zero-shot score that ``pred`` matches ``NLI_LABELS[0]``.

    Only the first 512 characters of ``pred`` are classified; a blank
    ``pred`` is replaced by ``"no findings"``. If the pipeline result
    does not contain ``NLI_LABELS[0]``, ``0.5`` is returned.
    """
    if pred.strip():
        text = pred[:512]
    else:
        text = "no findings"
    outcome = nli_pipe(text, candidate_labels=NLI_LABELS, multi_label=False)
    score_by_label = {
        label: score
        for label, score in zip(outcome["labels"], outcome["scores"])
    }
    return float(score_by_label.get(NLI_LABELS[0], 0.5))
def _baseline_rouge(ref: str, corpus_reports: list[str], scorer) -> float:
    """BM25 baseline — only used if corpus is passed (optional for server).

    Picks the corpus report with the highest BM25 score for ``ref``
    (lower-cased, whitespace-tokenized) and returns its ROUGE-L against
    ``ref``. An empty corpus returns ``0.0`` without importing
    ``rank_bm25``.
    """
    if not corpus_reports:
        return 0.0

    from rank_bm25 import BM25Okapi

    tokenized_corpus = [doc.lower().split() for doc in corpus_reports]
    query_tokens = ref.lower().split()
    ranking = BM25Okapi(tokenized_corpus).get_scores(query_tokens)
    best_match = corpus_reports[int(ranking.argmax())]
    return _rouge_l(best_match, ref, scorer)
def full_reward_breakdown(pred: str, ref: str,
                          corpus: list[str] | None = None) -> dict:
    """Compute every reward component for ``pred`` against ``ref``.

    IDENTICAL to inference.py XRayRewardEvaluator.breakdown().
    The contrastive term rescales ``rouge_l - BM25 baseline`` when a
    corpus is provided, otherwise it rescales ``rouge_l`` itself; both
    are clipped to ``[-1, 1]`` and mapped to ``[0, 1]``. ``hf_judge`` is
    the mean of the embedding similarity and the NLI quality score.
    ``total`` weights contrastive, rouge_l, negation_safety and hf_judge
    at 0.25 each.

    Blank inputs are replaced by ``"no findings noted"``.

    Returns:
        Dict with keys ``contrastive``, ``rouge_l``, ``negation_safety``,
        ``hf_judge``, ``biomed_sim``, ``nli_quality`` and ``total``, each
        rounded to 4 decimals.
    """
    pred = pred.strip() or "no findings noted"
    ref = ref.strip() or "no findings noted"
    scorer, embed_tok, embed_model, nli_pipe = _get_judges()

    rl = _rouge_l(pred, ref, scorer)
    neg_safety = _negation_safety(pred, ref)

    # Signal to rescale: improvement over the retrieval baseline when a
    # corpus is available, plain ROUGE-L otherwise.
    if corpus:
        signal = rl - _baseline_rouge(ref, corpus, scorer)
    else:
        signal = rl
    contrastive = (max(-1.0, min(1.0, signal)) + 1.0) / 2.0

    biomed_sim = _biomed_sim(pred, ref, embed_tok, embed_model)
    nli_q = _nli_quality(pred, nli_pipe)
    hf_judge = (biomed_sim + nli_q) / 2.0

    total = 0.25 * contrastive + 0.25 * rl + 0.25 * neg_safety + 0.25 * hf_judge

    components = {
        "contrastive": contrastive,
        "rouge_l": rl,
        "negation_safety": neg_safety,
        "hf_judge": hf_judge,
        "biomed_sim": biomed_sim,
        "nli_quality": nli_q,
        "total": total,
    }
    return {name: round(value, 4) for name, value in components.items()}
def quick_reward(report: str) -> tuple[float, str]:
    """Fast proxy score — no reference or judge models needed.

    Combines three heuristics: the share of a fixed medical vocabulary
    found in the report (weight 0.4), word count relative to 100 words
    capped at 1.0 (weight 0.3), and a structure score of 1.0 for 50-150
    words, else 0.5 (weight 0.3).

    Returns:
        ``(score, feedback)`` where ``score`` is clipped to ``[0, 1]``
        and ``feedback`` is a human-readable summary line.
    """
    vocabulary = (
        "lung", "heart", "normal", "clear", "opacity", "infiltrate",
        "cardiomegaly", "pleural", "pulmonary", "chest", "thorax",
        "pneumonia", "edema", "effusion", "consolidation",
    )
    lowered = report.lower()
    hits = sum(1 for term in vocabulary if term in lowered)
    word_count = len(report.split())

    term_s = hits / len(vocabulary)
    comp_s = min(1.0, word_count / 100.0)
    if word_count < 50 or word_count > 150:
        struct_s = 0.5
    else:
        struct_s = 1.0

    weighted = term_s*0.4 + comp_s*0.3 + struct_s*0.3
    score = max(0.0, min(1.0, weighted))

    feedback = " | ".join((
        f"Reward Score: {score:.2f}",
        f"Medical Terminology: {term_s:.1%}",
        f"Clinical Completeness: {comp_s:.1%}",
        f"Report Structure: {struct_s:.1%}",
    ))
    return score, feedback
# ─────────────────────────────────────────────────────────
# FastAPI
# ─────────────────────────────────────────────────────────
# Application object; CORS is fully open (any origin, method and header).
app = FastAPI(title="BioStack API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health():
    """Liveness probe: report a fixed ``ok`` status and the model device."""
    payload = {"status": "ok", "device": str(DEVICE)}
    return payload
@app.post("/sft")
async def sft_endpoint(file: UploadFile = File(...)):
    """Generate a report for the uploaded image with the SFT model.

    Always responds with ``{"report": ...}``; on failure the traceback is
    printed server-side and the report text is ``"ERROR: <message>"``.
    """
    # NOTE(review): generation runs synchronously inside this coroutine,
    # so the event loop is occupied until the report is finished.
    try:
        image_bytes = await file.read()
        pixels = preprocess(image_bytes)
        return {"report": generate_report(sft_model, pixels)}
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {e}"}
@app.post("/reward")
async def reward_endpoint(
    file: UploadFile = File(...),
    ground_truth: str = Form(default=""),
):
    """Generate an SFT report for the image and score it.

    With a non-blank ``ground_truth`` the full reward breakdown is
    computed and ``has_breakdown`` is True. Without one, the quick proxy
    score is used and ``hf_judge``/``breakdown`` are ``None``. On failure
    the traceback is printed and a zero-score payload carrying the error
    message in ``feedback`` is returned.
    """
    try:
        pixels = preprocess(await file.read())
        report = generate_report(sft_model, pixels)

        if not ground_truth.strip():
            # No reference text: fall back to the heuristic proxy score.
            score, feedback = quick_reward(report)
            return {
                "score": score,
                "feedback": feedback,
                "sft_report": report,
                # No ground truth → hf_judge not available
                "hf_judge": None,
                "breakdown": None,
                "has_breakdown": False,
            }

        bd = full_reward_breakdown(report, ground_truth)
        summary_fields = (
            ("Reward Score", "total"),
            ("Contrastive", "contrastive"),
            ("ROUGE-L", "rouge_l"),
            ("Negation Safety", "negation_safety"),
            ("HF Judge", "hf_judge"),
        )
        feedback = " | ".join(
            f"{label}: {bd[key]:.4f}" for label, key in summary_fields
        )
        return {
            "score": bd["total"],
            "feedback": feedback,
            "sft_report": report,
            # hf_judge is exposed at top level for frontend comparison.
            "hf_judge": bd["hf_judge"],
            "breakdown": bd,
            "has_breakdown": True,
        }
    except Exception as e:
        traceback.print_exc()
        return {"score": 0.0, "feedback": f"ERROR: {e}",
                "sft_report": "", "hf_judge": None,
                "breakdown": None, "has_breakdown": False}
@app.post("/grpo_reward")
async def grpo_reward_endpoint(
    file: UploadFile = File(...),
    ground_truth: str = Form(default=""),
):
    """Generate a GRPO report for the image, optionally with its breakdown.

    A non-blank ``ground_truth`` adds the full reward breakdown and the
    top-level ``hf_judge`` score; otherwise both are ``None`` and
    ``has_breakdown`` is False. On failure the traceback is printed and
    ``report`` carries ``"ERROR: <message>"``.
    """
    try:
        pixels = preprocess(await file.read())
        report = generate_report(grpo_model, pixels)

        # The breakdown is only computable against a reference report.
        if ground_truth.strip():
            bd = full_reward_breakdown(report, ground_truth)
        else:
            bd = None
        scored = bd is not None

        return {
            "report": report,
            # hf_judge is exposed at top level for frontend comparison;
            # no ground truth → hf_judge not available.
            "hf_judge": bd["hf_judge"] if scored else None,
            "breakdown": bd,
            "has_breakdown": scored,
        }
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {e}", "hf_judge": None,
                "breakdown": None, "has_breakdown": False}
# Serve React build AFTER all API routes: the catch-all "/" mount is
# registered last so it does not shadow the API routes declared before it.
if not os.path.exists("build"):
    print("WARNING: ./build not found")
else:
    app.mount("/", StaticFiles(directory="build", html=True), name="static")
    print("React build mounted at /")

# Direct execution: serve on all interfaces, port 7860, without auto-reload.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)