import re
import traceback

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModel,            # ← handles dinov2/ViT automatically
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from huggingface_hub import hf_hub_download
import gradio as gr


# ───────────────────────── Helpers: blocklist + text normalizer ─────────────────────────
def build_bad_words_ids(tok: AutoTokenizer):
    """Build token id sequences to block anonymization artifacts."""
    bad_phrases = [
        "XXXX", "xxxx", "X-XXXX", "x-XXXX", "x - XXXX", "x -xxxx", "x-xxxx",
        "X - XXXX", "x—XXXX", "x–XXXX", "x — XXXX", "x – XXXX",
        "x - XXX", "x-XXX", "x- xx", "x - xx",
    ]
    bad_ids = []
    for phrase in bad_phrases:
        ids = tok(phrase, add_special_tokens=False).input_ids
        # Skip phrases that tokenize entirely to <unk>; blocking those would be meaningless.
        if ids and not all(i == tok.unk_token_id for i in ids):
            bad_ids.append(ids)
    return bad_ids or None  # HF expects None, not [], when there is nothing to block


def normalize_report(text: str) -> str:
    """Clean up generated text: replace/remove anonymization placeholders."""
    if not text:
        return text
    # "x-XXXX", "x — XXXX", etc. → "x-ray"
    text = re.sub(r'\bx\s*[-–—]?\s*xxxx\b', 'x-ray', text, flags=re.IGNORECASE)
    # Drop any leftover bare "XXXX" tokens.
    text = re.sub(r'\bxxxx\b', '', text, flags=re.IGNORECASE)
    # Normalize dash variants in "x-ray".
    text = re.sub(r'\bx\s*[-–—]\s*ray\b', 'x-ray', text, flags=re.IGNORECASE)
    # Collapse whitespace and fix space-before-punctuation left by deletions.
    text = re.sub(r'\s+', ' ', text).strip()
    text = re.sub(r'\s+([.,;:])', r'\1', text)
    return text


# ─── 1) MODEL LOADING ─────────────────────────────────────────────────────────
def load_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    repo_id = "RakeshNJ12345/MMic-CXR"
    base_path = "mimic_trained/final"  # ← your confirmed path

    # 1a) Encoder (DINOv2/ViT via AutoModel so the config decides the class)
    enc_cfg = AutoConfig.from_pretrained(repo_id, subfolder=f"{base_path}/vit")
    vit = AutoModel.from_pretrained(repo_id, subfolder=f"{base_path}/vit").to(device)
    # enc_cfg.model_type will be 'dinov2' in your case; vit.config.hidden_size is available.
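    # Optional, and an assumption rather than a confirmed part of this checkpoint:
    # if the encoder subfolder also ships a preprocessor_config.json, the exact
    # training-time resize/normalization can be recovered with AutoImageProcessor
    # instead of the manual preprocess() defined below. A guarded sketch:
    #
    #   from transformers import AutoImageProcessor
    #   try:
    #       image_proc = AutoImageProcessor.from_pretrained(repo_id, subfolder=f"{base_path}/vit")
    #   except Exception:
    #       image_proc = None  # no processor config shipped; fall back to preprocess()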

    # 1b) Decoder & tokenizer (BioBART)
    dec = AutoModelForSeq2SeqLM.from_pretrained(repo_id, subfolder=f"{base_path}/decoder").to(device)
    tok = AutoTokenizer.from_pretrained(repo_id, subfolder=f"{base_path}/decoder")
    tok.clean_up_tokenization_spaces = True

    # 1c) Projection head (ViT hidden size → decoder d_model)
    proj_path = hf_hub_download(repo_id=repo_id, filename="proj.bin", subfolder=base_path)
    # weights_only=False: the checkpoint may be a pickled nn.Linear, which the
    # torch>=2.6 default (weights_only=True) refuses to load.
    loaded = torch.load(proj_path, map_location=device, weights_only=False)
    if isinstance(loaded, dict):
        sd = loaded.get("state_dict", loaded)
        proj = nn.Linear(vit.config.hidden_size, dec.config.d_model)
        proj.load_state_dict(sd)
    else:
        proj = loaded  # an nn.Linear was saved directly
    proj = proj.to(device)

    # 1d) Blocklist for anonymization artifacts
    bad_words_ids = build_bad_words_ids(tok)

    # 1e) Wrapper: two encoder views → averaged CLS → projected prefix → BioBART
    class TwoViewModel(nn.Module):
        def __init__(self, vit, dec, proj, tok, bad_words_ids=None):
            super().__init__()
            self.vit = vit
            self.dec = dec
            self.proj = proj
            self.tok = tok
            self.bad_words_ids = bad_words_ids

        def forward(self, *args, **kwargs):  # inference-only wrapper
            raise NotImplementedError

        @torch.no_grad()  # the ViT forward passes don't need autograd at inference
        def generate(self, img_f, img_l, finds, max_len=128, num_beams=4):
            device = img_f.device
            # CLS embeddings from both views
            out_f = self.vit(pixel_values=img_f).last_hidden_state[:, 0]
            out_l = self.vit(pixel_values=img_l).last_hidden_state[:, 0]
            # average + project → prefix embedding
            avg = 0.5 * (out_f + out_l)
            prefix = self.proj(avg).unsqueeze(1)  # [B,1,D]

            # prepend findings text (optional)
            if (finds or "").strip():
                enc = self.tok(finds, return_tensors="pt", padding=True, truncation=True).to(device)
                text_emb = self.dec.get_encoder().embed_tokens(enc.input_ids)
                enc_emb = torch.cat([prefix, text_emb], dim=1)
                mask = torch.cat([
                    torch.ones(prefix.size(0), 1, device=device, dtype=torch.long),
                    enc.attention_mask,
                ], dim=1)
            else:
                enc_emb = prefix
                mask = torch.ones(prefix.size(0), 1, device=device, dtype=torch.long)

            return self.dec.generate(
                inputs_embeds=enc_emb,
                attention_mask=mask,
                max_length=max_len,
                num_beams=num_beams,
                no_repeat_ngram_size=2,
                early_stopping=True,
                length_penalty=1.0,
                repetition_penalty=1.2,
                bad_words_ids=self.bad_words_ids,
                eos_token_id=self.tok.eos_token_id,
            )

    model = TwoViewModel(vit, dec, proj, tok, bad_words_ids=bad_words_ids).to(device)
    model.eval()  # disable dropout etc. for deterministic inference
    return model, tok, device


model, tokenizer, device = load_model()


# ─── 2) PREPROCESS ─────────────────────────────────────────────────────────────
def preprocess(img: Image.Image) -> torch.Tensor:
    # Basic resize + [0,1] scaling; works across ViT/DINOv2. Note this skips
    # mean/std normalization (see the optional image processor in load_model
    # for recovering the exact training-time transform).
    img = img.convert("RGB").resize((224, 224))
    arr = np.array(img).astype(np.float32) / 255.0  # HWC in [0,1]
    t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # [1,3,224,224]
    return t.to(device)
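

# Quick offline sanity check (a sketch, not part of the app): two random "views"
# pushed through the wrapper confirm that shapes, projection, and decoding wire
# up end to end before any real X-rays are involved.
#
#   dummy = torch.rand(1, 3, 224, 224, device=device)  # same range as preprocess()
#   ids = model.generate(dummy, dummy, "", max_len=32, num_beams=2)
#   print(tokenizer.decode(ids[0], skip_special_tokens=True))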


# ─── 3) GENERATION ─────────────────────────────────────────────────────────────
def generate_report(frontal, lateral, findings, beams, max_len):
    try:
        if frontal is None:
            return "⛔ Please upload a frontal X-ray."
        if lateral is None:
            lateral = frontal  # fall back to the frontal view when no lateral is given

        f_t = preprocess(frontal)
        l_t = preprocess(lateral)

        output_ids = model.generate(
            f_t, l_t, findings or "",
            max_len=int(max_len),   # Gradio sliders may deliver floats
            num_beams=int(beams),
        )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        text = normalize_report(text)  # clean up anonymization artifacts
        return text or ""
    except Exception as e:
        traceback.print_exc()
        return f"❌ Generation error: {repr(e)}"


# ─── 4) GRADIO INTERFACE ───────────────────────────────────────────────────────
iface = gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="pil", label="Frontal X-Ray"),
        gr.Image(type="pil", label="Lateral X-Ray (optional)"),
        gr.Textbox(label="Findings (optional)"),
        gr.Slider(1, 8, value=4, step=1, label="Beam width"),
        gr.Slider(50, 256, value=128, step=1, label="Max report length"),
    ],
    outputs=gr.Textbox(label="🩺 Generated Impression"),
    title="Two-View Chest X-Ray Report Generator",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
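
# Programmatic smoke test once the app is up (a sketch; assumes gradio_client >= 1.0
# and Gradio's default "/predict" endpoint for a single gr.Interface; the image
# filenames are hypothetical local test files):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   report = client.predict(
#       handle_file("frontal.png"),
#       handle_file("lateral.png"),
#       "",                # findings
#       4,                 # beam width
#       128,               # max report length
#       api_name="/predict",
#   )
#   print(report)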