"""
server.py — BioStack FastAPI Backend
Fully aligned to inference.py:
✅ Same CoAtNetEncoder : self.encoder + global_pool="avg" (NOT backbone)
✅ Same VisionT5ForGRPO : class name, generate() method
✅ Same image preprocess: raw /255.0, no ImageNet normalize
✅ Same generation : do_sample=True, temperature=0.9, top_p=0.999
✅ Same checkpoint logic: GRPO checkpoint-* > SFT best_model.pt
✅ Same reward : contrastive + rouge_l + negation_safety + hf_judge (0.25 each)
✅ Same state-dict keys : img_encoder.encoder.* (no _remap needed)
✅ HF Judge score exposed for both SFT and GRPO comparison
"""
import io, os, glob, traceback, gc, time
from pathlib import Path
from threading import Lock
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from transformers import T5ForConditionalGeneration, T5Tokenizer
from huggingface_hub import hf_hub_download
from rouge_score import rouge_scorer as rouge_scorer_lib
# ─────────────────────────────────────────────────────────
# Config — mirrors inference.py exactly
# ─────────────────────────────────────────────────────────
# Hugging Face Hub repository that _ensure() downloads checkpoints from.
HF_REPO = "AE-Shree/Biostack-Xray-NeMoGym"
# File name of the SFT checkpoint (requested without a subfolder).
SFT_FILE = "best_model.pt"
# Repo subfolder holding the GRPO checkpoint, and the two weight-file names
# that the startup code tries inside it (GRPO_BIN first, then GRPO_SAFE).
GRPO_SUB = "checkpoint-300"
GRPO_BIN = "pytorch_model.bin"
GRPO_SAFE = "model.safetensors"
# Local directory used as the download target; created at import time.
MODEL_DIR = Path("models")
MODEL_DIR.mkdir(exist_ok=True)
# Side length (pixels) that preprocess() resizes every input image to.
IMAGE_SIZE = 224
# Embedding model loaded by _get_judges() and used by _biomed_sim().
HF_EMBED_MODEL = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
# Model behind the zero-shot-classification pipeline used by _nli_quality().
HF_NLI_MODEL = "cross-encoder/nli-roberta-base"
# Device the tokenized judge inputs are moved to in _biomed_sim().
HF_JUDGE_DEVICE = "cpu"
# (negated phrase, positive term) pairs scanned by _negation_safety().
NEGATION_PAIRS = [
    ("no pneumonia", "pneumonia"),
    ("no effusion", "effusion"),
    ("no consolidation", "consolidation"),
    ("no cardiomegaly", "cardiomegaly"),
    ("no opacity", "opacity"),
    ("no infiltrate", "infiltrate"),
    ("clear lungs", "opacity"),
    ("clear lungs", "consolidation"),
    ("normal", "abnormal"),
]
# Candidate labels for the zero-shot judge; _nli_quality() returns the score
# assigned to the first label.
NLI_LABELS = [
    "accurate and complete radiology report",
    "inaccurate or incomplete radiology report",
]
# Report generation always runs on CPU in this server.
DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")
# ─────────────────────────────────────────────────────────
# Image preprocessing — IDENTICAL to inference.py
# Raw /255.0, permute, no ImageNet normalization
# ─────────────────────────────────────────────────────────
def preprocess(file_bytes: bytes) -> torch.Tensor:
    """Decode raw image bytes into a batched tensor on DEVICE.

    The image is converted to RGB, resized to IMAGE_SIZE x IMAGE_SIZE, scaled
    by 1/255 as float32, permuted from height-width-channel to
    channel-height-width order and given a leading batch axis of size 1.
    No mean/std normalization is applied.
    """
    image = Image.open(io.BytesIO(file_bytes))
    image = image.convert("RGB")
    image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
    pixels = np.array(image, dtype=np.float32) / 255.0
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
    batched = channels_first.unsqueeze(0)
    return batched.to(DEVICE)
# ─────────────────────────────────────────────────────────
# Model Architecture — IDENTICAL to inference.py
# CoAtNetEncoder uses self.encoder + global_pool="avg"
# Class name is VisionT5ForGRPO (not VisionT5)
# ─────────────────────────────────────────────────────────
class CoAtNetEncoder(nn.Module):
    """Image encoder wrapping a timm ``coatnet_1_rw_224`` model.

    The wrapped model is stored under the attribute name ``encoder``; that
    name is part of the parameter key prefix, so it must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Created without pretrained weights: the real weights arrive later
        # through load_state_dict(). num_classes=0 and global_pool="avg" are
        # passed straight through to timm.
        backbone_options = dict(pretrained=False, num_classes=0, global_pool="avg")
        self.encoder = timm.create_model("coatnet_1_rw_224", **backbone_options)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped timm model on ``x`` and return its output unchanged."""
        features = self.encoder(x)
        return features
class VisionT5ForGRPO(nn.Module):
    """Image-conditioned T5: image features become the T5 encoder's only input.

    Sub-modules are named ``img_encoder``, ``t5`` and ``proj``; these names
    form the state-dict key prefixes and must stay as they are.
    """

    def __init__(self):
        super().__init__()
        self.img_encoder = CoAtNetEncoder()
        self.t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
        # Project the 768-wide image feature to the T5 hidden size.
        t5_width = self.t5.config.d_model
        self.proj = nn.Linear(768, t5_width)

    def _encode_image(self, pixel_values: torch.Tensor):
        """Return (T5 encoder outputs, all-ones attention mask) for the images.

        Each image contributes a single embedding, inserted as a length-1
        sequence via ``unsqueeze(1)``.
        """
        image_features = self.img_encoder(pixel_values)
        image_token = self.proj(image_features).unsqueeze(1)
        encoder_states = self.t5.encoder(inputs_embeds=image_token)
        attention = torch.ones(
            image_token.shape[:2], dtype=torch.long, device=image_token.device
        )
        return encoder_states, attention

    def generate(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        """Generate token ids for ``pixel_values``.

        Extra keyword arguments are forwarded unchanged to ``self.t5.generate``.
        """
        encoder_states, attention = self._encode_image(pixel_values)
        return self.t5.generate(
            encoder_outputs=encoder_states,
            attention_mask=attention,
            **kwargs,
        )
# ─────────────────────────────────────────────────────────
# Checkpoint helpers — aligned to inference.py logic
# No _remap() needed: keys are img_encoder.encoder.* as saved
# ─────────────────────────────────────────────────────────
def _load_sd(path: str) -> dict:
if path.endswith(".safetensors"):
from safetensors.torch import load_file
sd = load_file(path, device="cpu")
else:
sd = torch.load(path, map_location="cpu", weights_only=False)
for wrap in ("state_dict", "model_state_dict", "model"):
if wrap in sd and isinstance(sd[wrap], dict):
sd = sd[wrap]
break
return sd
def _ensure(filename, subfolder=None) -> str:
    """Return a local path to ``filename``, downloading it from HF_REPO if needed.

    Args:
        filename: File name inside the repository (or inside ``subfolder``).
        subfolder: Optional repository subfolder containing the file.

    Returns:
        Path of an existing local copy under MODEL_DIR when one is found,
        otherwise the path returned by ``hf_hub_download``.
    """
    if subfolder:
        candidates = [
            # Flat "<subfolder>_<filename>" name this helper has always checked.
            MODEL_DIR / f"{subfolder}_{filename}",
            # Where hf_hub_download(local_dir=..., subfolder=...) actually puts
            # the file: the repository layout is mirrored below local_dir.
            # Without this check a previously downloaded file was never found.
            MODEL_DIR / subfolder / filename,
        ]
    else:
        candidates = [MODEL_DIR / filename]

    for local in candidates:
        if local.exists():
            return str(local)

    kw = dict(repo_id=HF_REPO, filename=filename, local_dir=str(MODEL_DIR))
    if subfolder:
        kw["subfolder"] = subfolder
    return hf_hub_download(**kw)
def _build(path: str) -> VisionT5ForGRPO:
    """Create a VisionT5ForGRPO, load the weights at ``path`` and freeze it.

    Loading is strict, so any missing or unexpected key raises. The returned
    model is in eval mode and none of its parameters require gradients.
    """
    model = VisionT5ForGRPO()
    weights = _load_sd(path)
    model.load_state_dict(weights, strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model
def _find_grpo_checkpoint() -> str | None:
    """Locate the weight file of the highest-numbered local GRPO checkpoint.

    Scans MODEL_DIR for ``checkpoint-<number>`` directories, picks the one
    with the largest number and looks inside it for GRPO_BIN, then GRPO_SAFE.

    Returns:
        Path of the first weight file found in the latest checkpoint
        directory, or None when there is no numbered checkpoint directory or
        it contains neither file.
    """
    numbered = []
    for entry in glob.glob(str(MODEL_DIR / "checkpoint-*")):
        suffix = os.path.basename(entry).split("checkpoint-")[-1]
        # Ignore plain files and names such as "checkpoint-best": int() on a
        # non-numeric suffix used to abort the whole scan with ValueError.
        if suffix.isdecimal() and os.path.isdir(entry):
            numbered.append((int(suffix), entry))
    if not numbered:
        return None

    _, latest = max(numbered)
    for fname in (GRPO_BIN, GRPO_SAFE):
        candidate = os.path.join(latest, fname)
        if os.path.exists(candidate):
            return candidate
    return None
# ─────────────────────────────────────────────────────────
# Load models at startup
# ─────────────────────────────────────────────────────────
# Import-time loading: tokenizer, SFT model and GRPO model are module-level
# globals used by the request handlers defined further down.
print("\n" + "="*60)
print("LOADING TOKENIZER")
tokenizer = T5Tokenizer.from_pretrained(
    "t5-small", truncation_side="left", padding_side="left", legacy=True)
print("OK")

print("LOADING SFT")
t0 = time.time()
sft_model = _build(_ensure(SFT_FILE))
print(f"SFT OK ({time.time()-t0:.1f}s)")

print("LOADING GRPO")
t0 = time.time()
try:
    grpo_path = _ensure(GRPO_BIN, GRPO_SUB)
except Exception as exc:
    # Deliberate best-effort fallback to the safetensors file, but report why
    # the first attempt failed instead of discarding the error silently.
    print(f"GRPO {GRPO_BIN} unavailable ({exc!r}); trying {GRPO_SAFE}")
    grpo_path = _ensure(GRPO_SAFE, GRPO_SUB)
grpo_model = _build(grpo_path)
print(f"GRPO OK ({time.time()-t0:.1f}s)")

gc.collect()
print("="*60 + "\nVisionT5ForGRPO models ready — reward judges load on first use\n" + "="*60)
# ─────────────────────────────────────────────────────────
# Report generation — IDENTICAL sampling params to inference.py
# do_sample=True, temperature=0.9, top_p=0.999, repetition_penalty=1.2
# ─────────────────────────────────────────────────────────
def _remove_numbers(text: str) -> str:
"""Remove all numbers from the generated text."""
import re
# Remove digits and decimal numbers
return re.sub(r'\d+\.?\d*', '', text).strip()
def generate_report(model: VisionT5ForGRPO, pixel_values: torch.Tensor) -> str:
    """Sample one report from ``model`` for ``pixel_values``.

    Generation uses sampling, so repeated calls on the same image can return
    different text. Only the first sequence of the output is decoded (special
    tokens skipped), and numbers are removed from it before it is returned.
    """
    sampling = dict(
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
        top_p=0.999,
        repetition_penalty=1.2,
    )
    with torch.no_grad():
        token_ids = model.generate(pixel_values=pixel_values, **sampling)
    decoded = tokenizer.decode(token_ids[0], skip_special_tokens=True)
    return _remove_numbers(decoded.strip())
# ─────────────────────────────────────────────────────────
# Lazy reward judge singletons
# ─────────────────────────────────────────────────────────
# Module-level caches filled by _get_judges() on first use; None until then.
_rouge_scorer_inst = None  # ROUGE-L scorer
_embed_tok_inst = None  # tokenizer for HF_EMBED_MODEL
_embed_model_inst = None  # embedding model for HF_EMBED_MODEL
_nli_pipe_inst = None  # zero-shot-classification pipeline for HF_NLI_MODEL
# Held by _get_judges() while it checks and populates the caches above.
_judge_lock = Lock()
def _get_judges():
    """Return the (rouge scorer, embed tokenizer, embed model, NLI pipeline) tuple.

    The judges are created on the first call and cached in module globals;
    later calls return the cached objects. The whole check-and-load sequence
    runs while holding ``_judge_lock``.

    Raises:
        Whatever the underlying loaders raise. In that case nothing is cached,
        so the next call retries from scratch.
    """
    global _rouge_scorer_inst, _embed_tok_inst, _embed_model_inst, _nli_pipe_inst
    with _judge_lock:
        if _rouge_scorer_inst is None:
            print("Loading reward judges...")
            from transformers import AutoModel, AutoTokenizer, pipeline

            # Build everything into locals first. Assigning the globals one
            # by one meant a failure halfway (e.g. a download error) left the
            # sentinel set and every later call returned None judges.
            scorer = rouge_scorer_lib.RougeScorer(
                ["rougeL"], use_stemmer=True)
            embed_tok = AutoTokenizer.from_pretrained(HF_EMBED_MODEL)
            embed_model = AutoModel.from_pretrained(
                HF_EMBED_MODEL).cpu().eval()
            nli_pipe = pipeline(
                "zero-shot-classification",
                model=HF_NLI_MODEL,
                device=-1,
                tokenizer=AutoTokenizer.from_pretrained(
                    HF_NLI_MODEL, use_fast=False),
            )

            # Publish only once every judge loaded successfully.
            _embed_tok_inst = embed_tok
            _embed_model_inst = embed_model
            _nli_pipe_inst = nli_pipe
            _rouge_scorer_inst = scorer
            print("Reward judges ready")
        return _rouge_scorer_inst, _embed_tok_inst, _embed_model_inst, _nli_pipe_inst
# ─────────────────────────────────────────────────────────
# Reward components — IDENTICAL to inference.py XRayRewardEvaluator
# ─────────────────────────────────────────────────────────
def _rouge_l(pred: str, ref: str, scorer) -> float:
if not pred.strip() or not ref.strip():
return 0.0
return scorer.score(ref, pred)["rougeL"].fmeasure
def _negation_safety(pred: str, ref: str) -> float:
    """Score agreement between ``pred`` and ``ref`` on negated findings.

    Starts at 1.0 and loses 0.25 for every NEGATION_PAIRS entry where one text
    contains the negated phrase while the other contains the positive term
    without the negated phrase. Matching is case-insensitive substring search.
    The result is floored at 0.0.

    NOTE(review): because matching is by substring, a pair whose positive term
    contains its negated phrase (such as "normal" / "abnormal") can never
    trigger a penalty. Left unchanged so the score keeps its current values.
    """
    pred_text = pred.lower()
    ref_text = ref.lower()
    violations = 0
    for negated, positive in NEGATION_PAIRS:
        pred_negates = negated in pred_text
        ref_negates = negated in ref_text
        # Prediction denies a finding that the reference asserts.
        if pred_negates and not ref_negates and positive in ref_text:
            violations += 1
        # Prediction asserts a finding that the reference denies.
        if ref_negates and not pred_negates and positive in pred_text:
            violations += 1
    return max(0.0, 1.0 - 0.25 * violations)
def _mean_pool(tok_emb, attn_mask):
mask = attn_mask.unsqueeze(-1).expand(tok_emb.size()).float()
return torch.sum(tok_emb * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
def _biomed_sim(pred: str, ref: str, embed_tok, embed_model) -> float:
    """Return an embedding similarity of ``pred`` and ``ref`` mapped to [0, 1].

    Both texts are tokenized together (padded, truncated to 512 tokens),
    embedded, mean-pooled over the attention mask and L2-normalized. Their
    dot product is then shifted from [-1, 1] to [0, 1].
    """
    batch = embed_tok(
        [pred, ref],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    batch = batch.to(HF_JUDGE_DEVICE)
    with torch.no_grad():
        hidden = embed_model(**batch).last_hidden_state
    pooled = _mean_pool(hidden, batch["attention_mask"])
    unit = F.normalize(pooled, p=2, dim=1)
    cosine = (unit[0] * unit[1]).sum().item()
    return (cosine + 1.0) / 2.0
def _nli_quality(pred: str, nli_pipe) -> float:
    """Return the zero-shot score of ``pred`` for the first NLI_LABELS entry.

    Only the first 512 characters of ``pred`` are classified. An empty or
    whitespace-only ``pred`` is replaced by "no findings". If the first label
    is missing from the pipeline result, 0.5 is returned.
    """
    if pred.strip():
        text = pred[:512]
    else:
        text = "no findings"
    result = nli_pipe(text, candidate_labels=NLI_LABELS, multi_label=False)
    score_by_label = {
        label: score
        for label, score in zip(result["labels"], result["scores"])
    }
    return float(score_by_label.get(NLI_LABELS[0], 0.5))
def _baseline_rouge(ref: str, corpus_reports: list[str], scorer) -> float:
"""BM25 baseline β only used if corpus is passed (optional for server)."""
if not corpus_reports:
return 0.0
from rank_bm25 import BM25Okapi
bm25 = BM25Okapi([r.lower().split() for r in corpus_reports])
scores = bm25.get_scores(ref.lower().split())
top_idx = int(scores.argmax())
return _rouge_l(corpus_reports[top_idx], ref, scorer)
def full_reward_breakdown(pred: str, ref: str,
                          corpus: list[str] | None = None) -> dict:
    """Compute every reward component for ``pred`` against ``ref``.

    Empty texts are replaced by "no findings noted". The contrastive term is
    ROUGE-L minus the BM25 baseline when ``corpus`` is given, otherwise
    ROUGE-L itself; either way it is clipped to [-1, 1] and mapped to [0, 1].
    ``hf_judge`` is the mean of the embedding similarity and the NLI quality
    score. ``total`` weights contrastive, rouge_l, negation_safety and
    hf_judge by 0.25 each.

    Returns:
        Dict with keys contrastive, rouge_l, negation_safety, hf_judge,
        biomed_sim, nli_quality and total, each rounded to 4 decimals.
    """
    pred = pred.strip() or "no findings noted"
    ref = ref.strip() or "no findings noted"
    scorer, embed_tok, embed_model, nli_pipe = _get_judges()

    rl = _rouge_l(pred, ref, scorer)
    neg_safety = _negation_safety(pred, ref)

    signal = rl - _baseline_rouge(ref, corpus, scorer) if corpus else rl
    contrastive = (max(-1.0, min(1.0, signal)) + 1.0) / 2.0

    biomed_sim = _biomed_sim(pred, ref, embed_tok, embed_model)
    nli_q = _nli_quality(pred, nli_pipe)
    hf_judge = (biomed_sim + nli_q) / 2.0

    total = 0.25 * contrastive + 0.25 * rl + 0.25 * neg_safety + 0.25 * hf_judge

    components = {
        "contrastive": contrastive,
        "rouge_l": rl,
        "negation_safety": neg_safety,
        "hf_judge": hf_judge,
        "biomed_sim": biomed_sim,
        "nli_quality": nli_q,
        "total": total,
    }
    return {name: round(value, 4) for name, value in components.items()}
def quick_reward(report: str) -> tuple[float, str]:
    """Return a heuristic (score, feedback) pair computed from the text alone.

    The score combines three parts, clipped to [0, 1]:
      * 40%: fraction of the fixed medical terms found as substrings,
      * 30%: word count / 100, capped at 1.0,
      * 30%: 1.0 when the report has 50 to 150 words, otherwise 0.5.
    """
    vocabulary = (
        "lung", "heart", "normal", "clear", "opacity", "infiltrate",
        "cardiomegaly", "pleural", "pulmonary", "chest", "thorax",
        "pneumonia", "edema", "effusion", "consolidation",
    )
    lowered = report.lower()
    hit_count = sum(1 for term in vocabulary if term in lowered)
    word_count = len(report.split())

    term_s = hit_count / len(vocabulary)
    comp_s = min(1.0, word_count / 100.0)
    if 50 <= word_count <= 150:
        struct_s = 1.0
    else:
        struct_s = 0.5

    weighted = term_s*0.4 + comp_s*0.3 + struct_s*0.3
    score = max(0.0, min(1.0, weighted))
    fb = (f"Reward Score: {score:.2f} | Medical Terminology: {term_s:.1%} | "
          f"Clinical Completeness: {comp_s:.1%} | "
          f"Report Structure: {struct_s:.1%}")
    return score, fb
# ─────────────────────────────────────────────────────────
# FastAPI
# ─────────────────────────────────────────────────────────
# Application object that all routes below are registered on.
app = FastAPI(title="BioStack API")
# CORS is fully open: any origin, method and header is allowed.
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_methods=["*"], allow_headers=["*"])
@app.get("/health")
def health():
    """Report that the server is up, plus the device used for generation."""
    payload = {"status": "ok", "device": str(DEVICE)}
    return payload
@app.post("/sft")
async def sft_endpoint(file: UploadFile = File(...)):
    """Generate a report for the uploaded image with the SFT model.

    Never raises: on failure the traceback is printed and the error text is
    returned in the ``report`` field.
    """
    try:
        image_bytes = await file.read()
        report = generate_report(sft_model, preprocess(image_bytes))
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {e}"}
    return {"report": report}
@app.post("/reward")
async def reward_endpoint(
    file: UploadFile = File(...),
    ground_truth: str = Form(default=""),
):
    """Generate an SFT report for the upload and score it.

    With a non-blank ``ground_truth`` the full reward breakdown is computed;
    otherwise the reference-free quick_reward() heuristic is used and
    ``hf_judge`` / ``breakdown`` are None. Never raises: on failure the
    traceback is printed and a zero-score payload with the error text in
    ``feedback`` is returned.
    """
    try:
        pixel_values = preprocess(await file.read())
        report = generate_report(sft_model, pixel_values)

        if not ground_truth.strip():
            # No reference text: only the heuristic score is available.
            score, feedback = quick_reward(report)
            return {
                "score": score,
                "feedback": feedback,
                "sft_report": report,
                "hf_judge": None,
                "breakdown": None,
                "has_breakdown": False,
            }

        bd = full_reward_breakdown(report, ground_truth)
        feedback = (
            f"Reward Score: {bd['total']:.4f} | "
            f"Contrastive: {bd['contrastive']:.4f} | "
            f"ROUGE-L: {bd['rouge_l']:.4f} | "
            f"Negation Safety: {bd['negation_safety']:.4f} | "
            f"HF Judge: {bd['hf_judge']:.4f}"
        )
        return {
            "score": bd["total"],
            "feedback": feedback,
            "sft_report": report,
            # hf_judge is duplicated at top level for the frontend comparison.
            "hf_judge": bd["hf_judge"],
            "breakdown": bd,
            "has_breakdown": True,
        }
    except Exception as e:
        traceback.print_exc()
        return {
            "score": 0.0,
            "feedback": f"ERROR: {e}",
            "sft_report": "",
            "hf_judge": None,
            "breakdown": None,
            "has_breakdown": False,
        }
@app.post("/grpo_reward")
async def grpo_reward_endpoint(
    file: UploadFile = File(...),
    ground_truth: str = Form(default=""),
):
    """Generate a GRPO report for the upload, scoring it when possible.

    With a non-blank ``ground_truth`` the full reward breakdown is attached;
    otherwise ``hf_judge`` / ``breakdown`` are None. Never raises: on failure
    the traceback is printed and the error text is returned in ``report``.
    """
    try:
        pixel_values = preprocess(await file.read())
        report = generate_report(grpo_model, pixel_values)

        if not ground_truth.strip():
            # No reference text, so no judge score can be computed.
            return {
                "report": report,
                "hf_judge": None,
                "breakdown": None,
                "has_breakdown": False,
            }

        bd = full_reward_breakdown(report, ground_truth)
        return {
            "report": report,
            # hf_judge is duplicated at top level for the frontend comparison.
            "hf_judge": bd["hf_judge"],
            "breakdown": bd,
            "has_breakdown": True,
        }
    except Exception as e:
        traceback.print_exc()
        return {
            "report": f"ERROR: {e}",
            "hf_judge": None,
            "breakdown": None,
            "has_breakdown": False,
        }
# Serve React build AFTER all API routes
# The catch-all "/" mount is registered last, after every API route above.
if not os.path.exists("build"):
    print("WARNING: ./build not found")
else:
    app.mount("/", StaticFiles(directory="build", html=True), name="static")
    print("React build mounted at /")
# Direct execution: serve the app with uvicorn on all interfaces, port 7860,
# with auto-reload disabled.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)