import os
import torch
import torch.nn as nn
from torchvision import transforms
from typing import Dict, Any
from PIL import Image
import open_clip
from transformers import (
    BioGptTokenizer,
    BioGptForCausalLM,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import gradio as gr
# NOTE: Ensure this library is installed on the Hugging Face Space
from IndicTransToolkit import IndicProcessor
from huggingface_hub import hf_hub_download  # used to download checkpoints for HF deployment
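
# Assumed Space dependencies (a plausible requirements.txt, not taken from the source):
# torch, torchvision, transformers, sacremoses (needed by BioGptTokenizer),
# open_clip_torch, gradio, huggingface_hub, Pillow, IndicTransToolkit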

# --- 1. CONFIGURATION (Stage 1: Report Generation) ---
# NOTE: Update this REPO_ID to the actual Hugging Face repository where you upload your .pth files!
REPO_ID = "Robinhood135/biogptm1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- MODEL/DECODING PARAMS ---
BIOMEDCLIP_MODEL_NAME = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
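# These are the standard OpenAI-CLIP normalization statistics; they are assumed to
# match the preprocessing used when the projection head was trained.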
PREFIX_LENGTH = 10
# NOTE: Prompt kept verbatim (including the missing space after the period); editing the
# string could shift tokenization away from the prompt the model saw during training.
PROMPT_TEXT = "You are a Radiologist.The chest image findings are:"

# --- BEST DECODING STRATEGY (Beam Search) ---
BEST_STRATEGY_PARAMS = {
    "num_beams": 4,
    "do_sample": False,
    "repetition_penalty": 1.2,
    "max_new_tokens": 100,
    "min_new_tokens": 10,
}
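# Decoding rationale: beam search with 4 beams and no sampling gives deterministic
# output, repetition_penalty=1.2 discourages the looping phrases causal LMs are prone
# to, and the min/max_new_tokens bounds keep each report between 10 and 100 new tokens.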

# --- 2. MODEL CLASS (Stage 1) ---
def freeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = False


class BiomedCLIPBioGPTGenerator(nn.Module):
    def __init__(self, tokenizer, model_name=BIOMEDCLIP_MODEL_NAME, prefix_length=PREFIX_LENGTH):
        super().__init__()
        self.tokenizer = tokenizer
        self.prefix_length = prefix_length
        self.clip_model, _, _ = open_clip.create_model_and_transforms(model_name)
        # Use the visual tower when it is exposed as an attribute; otherwise fall back
        # to the model's encode_image method
        self.image_encoder = self.clip_model.visual if hasattr(self.clip_model, 'visual') else self.clip_model.encode_image
        freeze_module(self.image_encoder)
        # Probe the frozen encoder once to discover its output embedding dimension
        with torch.no_grad():
            dummy_features = self.image_encoder(torch.randn(1, 3, 224, 224))
            if isinstance(dummy_features, tuple):
                dummy_features = dummy_features[0]
            self.embed_dim = dummy_features.shape[-1]
        # Load the pretrained BioGPT decoder once (a single from_pretrained call suffices)
        # and resize its embedding table in case the tokenizer gained a new pad token
        self.biogpt = BioGptForCausalLM.from_pretrained('microsoft/biogpt')
        self.biogpt.resize_token_embeddings(len(self.tokenizer))
        self.gpt_hidden_dim = self.biogpt.config.hidden_size
        self.biogpt.config.pad_token_id = self.tokenizer.pad_token_id
        self.projection_head = nn.Sequential(
            nn.Linear(self.embed_dim, self.prefix_length * self.gpt_hidden_dim),
            nn.Tanh(),
            nn.Linear(self.prefix_length * self.gpt_hidden_dim, self.prefix_length * self.gpt_hidden_dim)
        )
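
    # Shape flow (a ClipCap-style prefix scheme, as characterized here): the frozen image
    # encoder yields one [B, embed_dim] vector per image; the projection head expands it
    # to [B, prefix_length * gpt_hidden_dim], which get_prefix_embeddings reshapes into
    # [B, prefix_length, gpt_hidden_dim] soft-prompt embeddings for BioGPT.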
    def get_prefix_embeddings(self, images):
        clip_features = self.image_encoder(images).float()
        prefix_embeds = self.projection_head(clip_features)
        return prefix_embeds.view(-1, self.prefix_length, self.gpt_hidden_dim)

    def get_text_embeddings(self, input_ids):
        return self.biogpt.get_input_embeddings()(input_ids)

# --- 3. INFERENCE FUNCTION (Stage 1) ---
def generate_report(model, pil_image: Image.Image, method_params: Dict[str, Any]):
    model.eval()
    # 3.1 Apply image transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
    ])
    image_tensor = transform(pil_image.convert('RGB')).unsqueeze(0).to(device)
    # 3.2 Get prefix embeddings
    prefix_embeds = model.get_prefix_embeddings(image_tensor)
    # 3.3 Encode prompt text and concatenate it after the image prefix
    prompt_data = model.tokenizer(PROMPT_TEXT, return_tensors="pt").to(device)
    prompt_embeds = model.get_text_embeddings(prompt_data["input_ids"])
    combined_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
    prefix_att_mask = torch.ones(prefix_embeds.shape[:2], dtype=torch.long, device=device)
    combined_att_mask = torch.cat([prefix_att_mask, prompt_data["attention_mask"]], dim=1)
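    # The image prefix has no token ids, so prefix + prompt are handed to BioGPT as
    # inputs_embeds (with a matching attention mask) rather than as input_ids.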
    # 3.4 Generation parameters
    generation_args = {
        "inputs_embeds": combined_embeds,
        "attention_mask": combined_att_mask,
        "pad_token_id": model.tokenizer.pad_token_id,
        "eos_token_id": model.tokenizer.eos_token_id,
        "use_cache": True,
    }
    generation_args.update(method_params)
    # 3.5 Generate
    generated_ids = model.biogpt.generate(**generation_args)
    # 3.6 Decode and clean
    full_text = model.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    if full_text.startswith(PROMPT_TEXT):
        text = full_text[len(PROMPT_TEXT):].strip()
    else:
        text = full_text
    return text if text.strip() else "[BLANK/FAILED GENERATION]"
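
# Hypothetical local smoke test (assumes an xray.jpg file next to this script; not run on the Space):
#   report = generate_report(load_trained_generator(), Image.open("xray.jpg"), BEST_STRATEGY_PARAMS)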

# --- 4. MODEL LOADING (Stage 1) - pulls checkpoints from the HF Hub ---
def load_trained_generator():
    print(f"Loading Report Generator model from {REPO_ID}...")
    # Download the trained checkpoints from the Hugging Face Hub
    try:
        clip_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biomedclipp.pth")
        gpt_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biogptt.pth")
        proj_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="projectorr.pth")
    except Exception as e:
        raise FileNotFoundError(f"Failed to download one or more checkpoint files from {REPO_ID}. Error: {e}") from e
    # Initialize tokenizer
    base_tokenizer = BioGptTokenizer.from_pretrained('microsoft/biogpt')
    if base_tokenizer.pad_token is None:
        base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # Initialize model
    model = BiomedCLIPBioGPTGenerator(base_tokenizer).to(device)
    # Load CLIP encoder weights; the checkpoint may store them under 'model_state_dict',
    # 'state_dict', or as a bare state dict
    clip_checkpoint = torch.load(clip_ckpt_path, map_location=device)
    state_dict = clip_checkpoint.get('model_state_dict', clip_checkpoint.get('state_dict', clip_checkpoint))
    # Keep only visual-encoder entries and strip wrapper prefixes from the keys
    visual_state = {k.replace('model.visual.', '').replace('visual.', ''): v for k, v in state_dict.items() if 'visual' in k}
    # strict=False tolerates missing/unexpected keys left over after the renaming
    model.image_encoder.load_state_dict(visual_state, strict=False)
    # Load trained BioGPT and projection-head weights
    model.biogpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
    model.projection_head.load_state_dict(torch.load(proj_ckpt_path, map_location=device))
    model.eval()
    print("✅ Report Generator loaded successfully.")
    return model
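
# Hypothetical sanity check for the loaded projector (illustrative only, not executed at startup):
#   m = load_trained_generator()
#   emb = m.get_prefix_embeddings(torch.randn(1, 3, 224, 224).to(device))
#   assert emb.shape == (1, PREFIX_LENGTH, m.gpt_hidden_dim)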

# --- 5. MODEL LOADING (Stage 2: Translation) ---
def load_translator():
    # IndicTrans2 models are loaded directly from their HF repos (ai4bharat/...)
    print("Loading Translation model (IndicTrans2)...")
    try:
        # The IndicTransToolkit library must be installed on the Space
        ip = IndicProcessor(inference=True)
        model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # NOTE: If memory is tight on the Space, consider a smaller model or lower precision.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(device)
        print("✅ Translator loaded successfully.")
        return ip, tokenizer, model
    except Exception as e:
        print(f"Error loading translation model: {e}")
        # Return placeholders so the app can still start without translation
        return None, None, None

# Load models globally
GENERATOR_MODEL = load_trained_generator()
IP, TRANS_TOKENIZER, TRANS_MODEL = load_translator()

# --- 6. TRANSLATION FUNCTION (Stage 2) ---
def translate_report(english_text: str, target_lang: str = "hin_Deva") -> str:
    if TRANS_MODEL is None or not english_text:
        return "[Translation Model Not Available or No Text to Translate]"
    # 6.1 Preprocessing
    batch = IP.preprocess_batch([english_text], src_lang="eng_Latn", tgt_lang=target_lang, visualize=False)
    batch = TRANS_TOKENIZER(batch, padding="longest", truncation=True, max_length=256, return_tensors="pt").to(device)
    # 6.2 Generation
    outputs = TRANS_MODEL.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256, use_cache=False)
    # 6.3 Postprocessing
    outputs = TRANS_TOKENIZER.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    translated_text = IP.postprocess_batch(outputs, lang=target_lang)[0]
    return translated_text
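
# IndicTrans2 uses FLORES-200-style language tags, so other Indic targets should work
# the same way, e.g. translate_report(text, "ben_Beng") for Bengali or
# translate_report(text, "tam_Taml") for Tamil (tags assumed from the FLORES-200 list).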

# --- 7. GRADIO WRAPPER FUNCTION ---
def inference_wrapper(input_image: Image.Image):
    if input_image is None:
        return "Please upload a chest X-ray image.", "[No English Report]"
    # STAGE 1: GENERATE RAW ENGLISH REPORT
    try:
        raw_english_report = generate_report(GENERATOR_MODEL, input_image, BEST_STRATEGY_PARAMS)
    except Exception as e:
        raw_english_report = f"An error occurred during generation: {e}"
        return raw_english_report, "[Translation Skipped]"
    # STAGE 2: TRANSLATE RAW ENGLISH REPORT
    try:
        hindi_report = translate_report(raw_english_report, target_lang="hin_Deva")
    except Exception as e:
        hindi_report = f"[Translation failed: {e}]"
    return raw_english_report, hindi_report

# --- 8. GRADIO INTERFACE SETUP ---
if __name__ == "__main__":
    # Example image filenames (expected under the examples/ directory of the Space repo)
    EXAMPLE_FILENAMES = [
        "001c3589-7aed3964-f06ba8d5-03882592-d77f222c.jpg",
        "004438db-4a5d6ab3-acc6c408-5dce0934-7d30b269.jpg",
        "0006f2ea-d44c6b5e-aeea6fd2-a974657c-90a39211.jpg",
        "0008ba07-4e43d6f4-fc692a96-c18a27a8-10eea0cd.jpg",
        "001526e1-0d0b8a2d-87e74f7e-72646210-c635fee4.jpg",
        "00438e51-4f75714b-943c8edd-6740491f-f8307602.jpg",
        "001c78df-8ce750bd-c100a8e0-2874ea0e-09cdbd4e.jpg",
        "000b9235-69b5b7e2-1ec32996-50f79b97-46f939cf.jpg",
        # "0041603e-059f400f-c509c746-0da5c413-ee889ec1.jpg",
        "001198e2-a2adcc23-7253eb78-0dcb5eaa-b10ed183.jpg",
        "0003fc7c-3dfce751-9ff36dc3-8fa4f6d9-0515ce50.jpg",
        "0018ff6b-8ad1196f-823030d0-1141b667-2a1a117a.jpg",
        "00068d26-8d583659-af7de1da-fc6c0476-d94aada1.jpg",
        "00196af8-50d17b31-b1b5a7be-da90b7e6-fd3a8004.jpg",
        "004017bd-6506697c-3ead0e70-548114b7-2af62447.jpg",
        "00059571-ade80b6c-7931ddb8-b486c6c1-1e543b22.jpg",
        "00419c98-6f4860a1-3dee986d-8e2ceadc-d2fd30ae.jpg",
        "000ffbff-3d93bcef-da8b17cd-fbcede53-51728df9.jpg",
        "0016e39b-d0cad5f2-eecb7ae8-4db8b8f2-0b366f1a.jpg",
        "00469c3d-4ebf8374-055428f7-d798daca-3e37d354.jpg",
        "0013ac79-5eea664c-7ef52c71-7e5a25f3-013715fc.jpg"
    ]
    # Create the examples list with image paths only
    examples = [[os.path.join("examples", f)] for f in EXAMPLE_FILENAMES]
    # Interface components
    input_image = gr.Image(type="pil", label="Upload Chest X-ray Image")
    output_en = gr.Textbox(label="Generated Radiology Report (English)", lines=5)
    output_hi = gr.Textbox(label="Generated Radiology Report (Hindi/हिन्दी)", lines=5)
    # Gradio app setup
    app = gr.Interface(
        fn=inference_wrapper,
        inputs=input_image,
        outputs=[output_en, output_hi],
        title="🔬 Radiology Report Generation from Chest X-rays in an Indic Language (Hindi)",
        description="Upload a chest X-ray image to generate radiology findings in English and automatically translate them to Hindi.",
        article="<div style='text-align: center; margin-top: 20px;'>⚠️ <em>This system is intended solely for research, demonstration, and educational purposes. All findings must be reviewed and interpreted by a qualified healthcare professional.</em></div>",
        # allow_flagging="never",
        examples=examples,
        cache_examples=False
        # cache_examples=True
    )
    print("\nStarting Gradio interface...")
    app.launch()  # share=True is not needed when deploying on a Hugging Face Space