Hugging Face Space status: Sleeping
| import os | |
| import re | |
| import traceback | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import ( | |
| AutoConfig, AutoModel, # β handles dinov2/ViT automatically | |
| AutoModelForSeq2SeqLM, AutoTokenizer | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| # βββββββββββββββββββββββββ Helpers: blocklist + text normalizer βββββββββββββββββββββββββ | |
def build_bad_words_ids(tok: AutoTokenizer):
    """Tokenize known anonymization placeholders into id sequences suitable
    for HF `generate(bad_words_ids=...)`; return None when none tokenize
    to anything usable (HF expects None rather than an empty list)."""
    phrases = [
        "XXXX", "xxxx", "X-XXXX", "x-XXXX", "x - XXXX", "x -xxxx", "x-xxxx",
        "X - XXXX", "xβXXXX", "xβXXXX", "x β XXXX", "x β XXXX",
        "x - XXX", "x-XXX", "x- xx", "x - xx",
    ]
    sequences = []
    for phrase in phrases:
        token_ids = tok(phrase, add_special_tokens=False).input_ids
        # Keep the phrase only if it produced ids and at least one of them
        # is a real (non-UNK) token.
        if token_ids and any(t != tok.unk_token_id for t in token_ids):
            sequences.append(token_ids)
    return sequences or None
def normalize_report(text: str) -> str:
    """Post-process generated report text: map anonymized "x-XXXX" back to
    "x-ray", drop bare "XXXX" placeholders, and tidy whitespace and
    spacing before punctuation. Empty/None input is returned unchanged."""
    if not text:
        return text
    # Ordered placeholder substitutions (all case-insensitive).
    substitutions = [
        (r'\bx\s*[-ββ]?\s*xxxx\b', 'x-ray'),  # "x-XXXX" was almost surely "x-ray"
        (r'\bxxxx\b', ''),                    # stray placeholder -> remove
        (r'\bx\s*[-ββ]\s*ray\b', 'x-ray'),    # unify dash variants of "x-ray"
    ]
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
    # Collapse runs of whitespace, then pull punctuation back onto the word.
    text = re.sub(r'\s+', ' ', text).strip()
    return re.sub(r'\s+([.,;:])', r'\1', text)
| # βββ 1) MODEL LOADING βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def load_model():
    """Download and assemble the two-view report generator from the HF Hub.

    Pieces: a vision encoder (DINOv2/ViT, chosen by its stored config), a
    seq2seq decoder with its tokenizer, and a Linear projection mapping the
    encoder's CLS embedding into the decoder's embedding space.

    Returns:
        (model, tok, device): the wrapped TwoViewModel, the decoder
        tokenizer, and the torch.device everything was moved to.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    repo_id = "RakeshNJ12345/MMic-CXR"
    base_path = "mimic_trained/final"  # subfolder inside the HF repo holding the weights
    # 1a) Encoder: AutoModel lets the stored config pick the concrete class
    # (dinov2/ViT). NOTE(review): enc_cfg is loaded but never used directly —
    # AutoModel re-reads the config itself; kept for parity with the original.
    enc_cfg = AutoConfig.from_pretrained(repo_id, subfolder=f"{base_path}/vit")
    vit = AutoModel.from_pretrained(repo_id, subfolder=f"{base_path}/vit").to(device)
    # 1b) Decoder & tokenizer (BioBART per the original author's note).
    dec = AutoModelForSeq2SeqLM.from_pretrained(repo_id, subfolder=f"{base_path}/decoder").to(device)
    tok = AutoTokenizer.from_pretrained(repo_id, subfolder=f"{base_path}/decoder")
    tok.clean_up_tokenization_spaces = True
    # 1c) Projection head: proj.bin may hold either a state dict (possibly
    # nested under a "state_dict" key) or a whole pickled nn.Linear.
    proj_path = hf_hub_download(repo_id=repo_id, filename="proj.bin", subfolder=base_path)
    loaded = torch.load(proj_path, map_location=device)
    if isinstance(loaded, dict):
        sd = loaded.get("state_dict", loaded)
        proj = nn.Linear(vit.config.hidden_size, dec.config.d_model)
        proj.load_state_dict(sd)
    else:
        proj = loaded  # an nn.Linear that was saved whole
    proj = proj.to(device)
    # 1d) Token-id sequences to suppress anonymization artifacts at decode time.
    bad_words_ids = build_bad_words_ids(tok)

    # 1e) Wrapper tying encoder, projection and decoder together.
    class TwoViewModel(nn.Module):
        def __init__(self, vit, dec, proj, tok, bad_words_ids=None):
            super().__init__()
            self.vit = vit
            self.dec = dec
            self.proj = proj
            self.tok = tok
            self.bad_words_ids = bad_words_ids

        def forward(self, *args, **kwargs):
            # Inference-only wrapper: a training forward is intentionally absent.
            raise NotImplementedError

        def generate(self, img_f, img_l, finds, max_len=128, num_beams=4):
            """Generate report token ids from two image views + optional findings.

            img_f / img_l: preprocessed pixel tensors — assumed [B, 3, H, W],
            TODO confirm against preprocess(). finds: free-text findings
            string (may be empty). Returns decoder output token ids.
            """
            device = img_f.device
            # CLS embedding (sequence position 0) from each view.
            out_f = self.vit(pixel_values=img_f).last_hidden_state[:, 0]
            out_l = self.vit(pixel_values=img_l).last_hidden_state[:, 0]
            # Average the two views, project into decoder space -> 1-token prefix.
            avg = 0.5 * (out_f + out_l)
            prefix = self.proj(avg).unsqueeze(1)  # [B, 1, D]
            if (finds or "").strip():
                # Prepend the image prefix to the embedded findings text and
                # extend the attention mask with a 1 for the prefix position.
                enc = self.tok(finds, return_tensors="pt", padding=True, truncation=True).to(device)
                text_emb = self.dec.get_encoder().embed_tokens(enc.input_ids)
                enc_emb = torch.cat([prefix, text_emb], dim=1)
                mask = torch.cat([
                    torch.ones(prefix.size(0), 1, device=device, dtype=torch.long),
                    enc.attention_mask
                ], dim=1)
            else:
                # No findings: condition only on the single image-prefix embedding.
                enc_emb = prefix
                mask = torch.ones(prefix.size(0), 1, device=device, dtype=torch.long)
            # Beam search with repetition controls and the placeholder blocklist.
            return self.dec.generate(
                inputs_embeds=enc_emb,
                attention_mask=mask,
                max_length=max_len,
                num_beams=num_beams,
                no_repeat_ngram_size=2,
                early_stopping=True,
                length_penalty=1.0,
                repetition_penalty=1.2,
                bad_words_ids=self.bad_words_ids,
                eos_token_id=self.tok.eos_token_id,
            )

    model = TwoViewModel(vit, dec, proj, tok, bad_words_ids=bad_words_ids).to(device)
    return model, tok, device
| model, tokenizer, device = load_model() | |
| # βββ 2) PREPROCESS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def preprocess(img: "Image.Image", target_device=None) -> torch.Tensor:
    """Convert a PIL image to a [1, 3, 224, 224] float32 tensor in [0, 1].

    Args:
        img: input image (any mode; converted to RGB).
        target_device: where to place the tensor; defaults to the
            module-level `device` chosen by load_model(), preserving the
            original behavior for existing callers.

    Returns:
        Tensor of shape [1, 3, 224, 224], values scaled to [0, 1].

    NOTE(review): only resize + /255 scaling is applied — no mean/std
    normalization; presumably matches how the encoder was trained (confirm).
    """
    if target_device is None:
        target_device = device  # module-level device from load_model()
    img = img.convert("RGB").resize((224, 224))
    # convert("RGB") guarantees a 3-channel image, so no grayscale fallback
    # is needed (the original ndim==2 branch was unreachable).
    arr = np.array(img).astype(np.float32) / 255.0
    t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # [1,3,224,224]
    return t.to(target_device)
| # βββ 3) GENERATION βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def generate_report(frontal, lateral, findings, beams, max_len):
    """Gradio handler: run the two-view model and return report text.

    Args:
        frontal: frontal-view PIL image (required).
        lateral: lateral-view PIL image; falls back to the frontal view
            when omitted.
        findings: optional free-text findings to condition generation on.
        beams, max_len: decoding knobs from the UI sliders.

    Returns:
        The cleaned report text, "<empty>" if generation produced nothing,
        or a human-readable error string (never raises — this is the UI
        boundary).
    """
    try:
        if frontal is None:
            # Fixed: the original message contained a mojibake'd glyph ("β").
            return "⚠️ Please upload a frontal X-ray."
        if lateral is None:
            lateral = frontal  # single-view fallback: reuse the frontal image
        f_t = preprocess(frontal)
        l_t = preprocess(lateral)
        # Gradio sliders can deliver floats; generate() expects integers.
        output_ids = model.generate(
            f_t, l_t, findings or "", max_len=int(max_len), num_beams=int(beams)
        )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        text = normalize_report(text)  # scrub anonymization artifacts
        return text or "<empty>"
    except Exception as e:
        # Top-level UI boundary: log the full traceback, surface a short message.
        traceback.print_exc()
        return f"⚠️ Generation error: {repr(e)}"
| # βββ 4) GRADIO INTERFACE βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Gradio UI: two image inputs, optional findings text, two decoding sliders;
# a single textbox output with the generated impression.
iface = gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="pil", label="Frontal X-Ray"),
        gr.Image(type="pil", label="Lateral X-Ray (optional)"),
        gr.Textbox(label="Findings (optional)"),
        gr.Slider(1, 8, value=4, step=1, label="Beam width"),
        gr.Slider(50, 256, value=128, step=1, label="Max report length"),
    ],
    # NOTE(review): the label below looks mojibake-garbled (likely an emoji
    # originally) — confirm the intended glyph before changing it.
    outputs=gr.Textbox(label="π©Ί Generated Impression"),
    title="Two-View Chest X-Ray Report Generator",
    # NOTE(review): allow_flagging was deprecated in Gradio 4.x in favor of
    # flagging_mode — confirm against the installed Gradio version.
    allow_flagging="never",
)
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard binding for Hugging Face Spaces containers.
    iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)