# app.py — Hugging Face Space "abc" by RakeshNJ12345 (commit 0bdd3ac, verified)
import os
import re
import traceback
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from transformers import (
AutoConfig, AutoModel, # ← handles dinov2/ViT automatically
AutoModelForSeq2SeqLM, AutoTokenizer
)
from huggingface_hub import hf_hub_download
import gradio as gr
# ───────────────────────── Helpers: blocklist + text normalizer ─────────────────────────
def build_bad_words_ids(tok: AutoTokenizer):
    """Return ``bad_words_ids`` sequences that block anonymization placeholders.

    Tokenizes a fixed set of "XXXX"-style artifact phrases and collects their
    id sequences for ``generate(bad_words_ids=...)``. Returns None when no
    usable sequence survives tokenization (HF expects None rather than []).
    """
    artifact_phrases = (
        "XXXX", "xxxx", "X-XXXX", "x-XXXX", "x - XXXX", "x -xxxx", "x-xxxx",
        "X - XXXX", "xβ€”XXXX", "x–XXXX", "x β€” XXXX", "x – XXXX",
        "x - XXX", "x-XXX", "x- xx", "x - xx",
    )
    blocked = []
    for phrase in artifact_phrases:
        token_ids = tok(phrase, add_special_tokens=False).input_ids
        # Skip phrases that tokenize to nothing or to a pure <unk> run.
        if token_ids and any(i != tok.unk_token_id for i in token_ids):
            blocked.append(token_ids)
    return blocked if blocked else None
def normalize_report(text: str) -> str:
    """Clean generated report text of anonymization placeholders.

    'x-XXXX'-style spans become 'x-ray', leftover 'xxxx' runs are removed,
    then whitespace and space-before-punctuation are tidied up.
    """
    if not text:
        return text
    # (pattern, replacement) pairs applied in order: the 'x-xxxx' -> 'x-ray'
    # rewrite must run before bare 'xxxx' removal would eat the placeholder.
    substitutions = (
        (r'\bx\s*[-–—]?\s*xxxx\b', 'x-ray'),
        (r'\bxxxx\b', ''),
        (r'\bx\s*[-–—]\s*ray\b', 'x-ray'),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
    text = re.sub(r'\s+', ' ', text).strip()
    return re.sub(r'\s+([.,;:])', r'\1', text)
# ─── 1) MODEL LOADING ─────────────────────────────────────────────────────────
def load_model():
    """Load the two-view CXR captioning pipeline from the Hugging Face Hub.

    Returns:
        (model, tok, device): a ``TwoViewModel`` wrapper combining a
        ViT/DINOv2 encoder, a BioBART seq2seq decoder, and a linear
        projection head mapping encoder CLS embeddings into the decoder's
        embedding space; plus the tokenizer and the torch device used.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    repo_id = "RakeshNJ12345/MMic-CXR"
    base_path = "mimic_trained/final"  # ← confirmed checkpoint layout in the repo
    # 1a) Encoder (DINOv2/ViT): AutoModel lets the stored config pick the class
    # (config.model_type is 'dinov2' here; vit.config.hidden_size is available).
    vit = AutoModel.from_pretrained(repo_id, subfolder=f"{base_path}/vit").to(device)
    vit.eval()  # inference only; from_pretrained already does this, but be explicit
    # 1b) Decoder & tokenizer (BioBART)
    dec = AutoModelForSeq2SeqLM.from_pretrained(repo_id, subfolder=f"{base_path}/decoder").to(device)
    dec.eval()
    tok = AutoTokenizer.from_pretrained(repo_id, subfolder=f"{base_path}/decoder")
    tok.clean_up_tokenization_spaces = True
    # 1c) Projection head. weights_only=False is required on torch >= 2.6
    # (where True became the default) so both a raw state_dict and a pickled
    # nn.Linear object can still be loaded. Only safe because the checkpoint
    # comes from our own trusted repo.
    proj_path = hf_hub_download(repo_id=repo_id, filename="proj.bin", subfolder=base_path)
    loaded = torch.load(proj_path, map_location=device, weights_only=False)
    if isinstance(loaded, dict):
        # Either a bare state_dict or a {"state_dict": ...} checkpoint wrapper.
        sd = loaded.get("state_dict", loaded)
        proj = nn.Linear(vit.config.hidden_size, dec.config.d_model)
        proj.load_state_dict(sd)
    else:
        proj = loaded  # checkpoint saved the nn.Linear module directly
    proj = proj.to(device)
    # 1d) Token-id blocklist for anonymization artifacts ("XXXX" etc.)
    bad_words_ids = build_bad_words_ids(tok)

    # 1e) Wrapper tying encoder, projection head and decoder together.
    class TwoViewModel(nn.Module):
        def __init__(self, vit, dec, proj, tok, bad_words_ids=None):
            super().__init__()
            self.vit = vit
            self.dec = dec
            self.proj = proj
            self.tok = tok
            self.bad_words_ids = bad_words_ids

        def forward(self, *args, **kwargs):
            # Training entry point intentionally unsupported in this app.
            raise NotImplementedError

        @torch.no_grad()  # inference only: skip autograd bookkeeping in the encoder pass
        def generate(self, img_f, img_l, finds, max_len=128, num_beams=4):
            """Generate report token ids from frontal/lateral image tensors.

            img_f / img_l: [B,3,H,W] pixel tensors already on the model device.
            finds: optional free-text findings prepended to the visual prefix.
            """
            device = img_f.device
            # CLS embeddings from both views.
            out_f = self.vit(pixel_values=img_f).last_hidden_state[:, 0]
            out_l = self.vit(pixel_values=img_l).last_hidden_state[:, 0]
            # Average the two views, then project into the decoder embedding
            # space as a single prefix "token": [B,1,D].
            avg = 0.5 * (out_f + out_l)
            prefix = self.proj(avg).unsqueeze(1)
            if (finds or "").strip():
                # Prepend findings text: embed it with the decoder's encoder
                # embedding table and concatenate after the visual prefix.
                enc = self.tok(finds, return_tensors="pt", padding=True, truncation=True).to(device)
                text_emb = self.dec.get_encoder().embed_tokens(enc.input_ids)
                enc_emb = torch.cat([prefix, text_emb], dim=1)
                mask = torch.cat([
                    torch.ones(prefix.size(0), 1, device=device, dtype=torch.long),
                    enc.attention_mask
                ], dim=1)
            else:
                enc_emb = prefix
                mask = torch.ones(prefix.size(0), 1, device=device, dtype=torch.long)
            return self.dec.generate(
                inputs_embeds=enc_emb,
                attention_mask=mask,
                max_length=max_len,
                num_beams=num_beams,
                no_repeat_ngram_size=2,
                early_stopping=True,
                length_penalty=1.0,
                repetition_penalty=1.2,
                bad_words_ids=self.bad_words_ids,
                eos_token_id=self.tok.eos_token_id,
            )

    model = TwoViewModel(vit, dec, proj, tok, bad_words_ids=bad_words_ids).to(device)
    return model, tok, device
# Load once at import time so every Gradio callback reuses the same
# model/tokenizer/device (avoids re-downloading weights per request).
model, tokenizer, device = load_model()
# ─── 2) PREPROCESS ─────────────────────────────────────────────────────────────
def preprocess(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a [1,3,224,224] float32 tensor on the global device.

    Pixels are scaled to [0,1] only.
    NOTE(review): no mean/std (e.g. ImageNet) normalization is applied —
    assumes the encoder was trained on plain [0,1] inputs; confirm against
    the training pipeline.
    """
    # convert("RGB") guarantees an HxWx3 array, so no grayscale branch is needed.
    img = img.convert("RGB").resize((224, 224))
    arr = np.asarray(img, dtype=np.float32) / 255.0
    t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # [1,3,224,224]
    return t.to(device)
# ─── 3) GENERATION ─────────────────────────────────────────────────────────────
def generate_report(frontal, lateral, findings, beams, max_len):
    """Gradio callback: generate an impression from one or two CXR views.

    Falls back to the frontal image when no lateral view is supplied, and
    returns a user-facing error string instead of raising so the UI stays up.
    """
    try:
        if frontal is None:
            return "β›” Please upload a frontal X-ray."
        if lateral is None:
            lateral = frontal  # single-view fallback: reuse the frontal image
        f_t = preprocess(frontal)
        l_t = preprocess(lateral)
        # Gradio sliders may deliver floats; HF generate() needs integer
        # num_beams / max_length, so cast explicitly.
        output_ids = model.generate(
            f_t, l_t, findings or "", max_len=int(max_len), num_beams=int(beams)
        )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        text = normalize_report(text)  # strip anonymization artifacts
        return text or "<empty>"
    except Exception as e:
        traceback.print_exc()
        return f"❌ Generation error: {repr(e)}"
# ─── 4) GRADIO INTERFACE ───────────────────────────────────────────────────────
# Two image inputs (lateral optional), optional findings text, and two
# decoding knobs wired straight into generate_report's parameters.
iface = gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="pil", label="Frontal X-Ray"),
        gr.Image(type="pil", label="Lateral X-Ray (optional)"),
        gr.Textbox(label="Findings (optional)"),
        gr.Slider(1, 8, value=4, step=1, label="Beam width"),
        gr.Slider(50, 256, value=128, step=1, label="Max report length"),
    ],
    outputs=gr.Textbox(label="🩺 Generated Impression"),
    title="Two-View Chest X-Ray Report Generator",
    # NOTE(review): allow_flagging is deprecated in Gradio 4+ in favor of
    # flagging_mode — verify against the installed Gradio version.
    allow_flagging="never",
)
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard Hugging Face Spaces binding.
    iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)