import asyncio
import io
import os
import threading
import traceback

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
| |
| |
| |
# Architecture and generation settings shared by model construction,
# preprocessing and every inference path in this file.
CONFIG = dict(
    coatnet_model='coatnet_1_rw_224',  # timm backbone name for CoAtNetEncoder
    t5_model='t5-small',               # T5 weights and tokenizer name
    img_emb_dim=768,                   # input width of the image->T5 projection
    train_last_stages=2,               # backbone stages left trainable
    image_size=224,                    # square resize applied to inputs
    max_length=100,                    # generation length limit
    num_beams=4,                       # beam-search width
)
|
|
| |
| |
| |
# Compute device used for every model and input tensor: the CUDA device when
# one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️ Using device: {device}")
|
|
| |
| |
| |
# Tokenizer used to decode generated ids back to text. It is loaded from the
# same CONFIG['t5_model'] name the T5 models are built from, so the
# vocabularies agree.
print("\n" + "="*80)
print("LOADING TOKENIZER")
print("="*80)
tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
print(f"✓ Loaded tokenizer: {CONFIG['t5_model']}")
|
|
| |
| |
| |
# Preprocessing applied to every input image:
#   1. resize to a CONFIG['image_size'] square,
#   2. convert to a tensor,
#   3. normalize per channel.
# The mean/std values match the usual ImageNet statistics, presumably what the
# backbone was trained with (TODO confirm against the training notebook).
transform = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
print(f"✓ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
|
|
| |
| |
| |
class CoAtNetEncoder(nn.Module):
    """timm CoAtNet backbone used as a pooled image-feature extractor.

    All backbone parameters are frozen. Then, if the backbone exposes a
    ``stages`` attribute, the last ``train_last_stages`` stages are made
    trainable again.

    Args:
        model_name: timm model name passed to ``timm.create_model``.
        pretrained: whether timm should load its pretrained weights.
        train_last_stages: number of trailing stages to unfreeze. ``None``
            or a value ``<= 0`` leaves the whole backbone frozen.
    """

    def __init__(self, model_name="coatnet_1_rw_224", pretrained=True, train_last_stages=2):
        super().__init__()
        # In timm, num_classes=0 drops the classifier head, so forward()
        # returns the (average-)pooled features.
        self.encoder = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )

        for p in self.encoder.parameters():
            p.requires_grad = False

        if hasattr(self.encoder, "stages") and train_last_stages is not None:
            self._unfreeze_last_stages(self.encoder.stages, train_last_stages)

    @staticmethod
    def _unfreeze_last_stages(stages, n):
        """Set ``requires_grad=True`` on every parameter of the last ``n`` stages.

        ``n <= 0`` is a no-op. A bare ``stages[-0:]`` slice would select
        *all* stages and silently unfreeze the whole backbone.
        """
        if n <= 0:
            return
        for stage in stages[-n:]:
            for p in stage.parameters():
                p.requires_grad = True

    def forward(self, x):
        """Return the backbone's pooled features for the image batch ``x``."""
        return self.encoder(x)
|
|
|
|
| |
| |
| |
class VisionT5Model(nn.Module):
    """Image encoder feeding a T5 encoder-decoder, for image-to-text generation.

    Image features are linearly projected to T5's ``d_model`` and given a
    sequence axis of length 1. They are then run through the T5 encoder as
    ``inputs_embeds``, and the T5 decoder produces text from that encoding.
    This path must stay identical to the one used in training (it is meant
    to match the original Colab notebook).

    Args:
        img_encoder: module mapping ``pixel_values`` to image features whose
            last dimension is ``img_emb_dim``. The ``unsqueeze(1)`` below
            assumes the features are ``(batch, img_emb_dim)`` — TODO confirm
            for encoders other than CoAtNetEncoder.
        txt_model_name: name or path for
            ``T5ForConditionalGeneration.from_pretrained``.
        img_emb_dim: width of the image features (input size of ``proj``).
    """

    def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
        super().__init__()
        self.img_encoder = img_encoder
        self.t5 = T5ForConditionalGeneration.from_pretrained(txt_model_name)
        # Bridges the image-feature width to the T5 hidden size.
        self.proj = nn.Linear(img_emb_dim, self.t5.config.d_model)
        # Keep the T5 shared embedding (`t5.shared`) frozen.
        for p in self.t5.shared.parameters():
            p.requires_grad = False

    def _encode_images(self, pixel_values):
        """Run the image path and the T5 encoder.

        Returns:
            ``(encoder_outputs, encoder_attention_mask)``. The mask is all
            ones over the encoder sequence. It is created on the device of
            the encoder inputs, not a module-level device, so the model works
            wherever it has been moved.
        """
        img_feats = self.proj(self.img_encoder(pixel_values))
        encoder_hidden_states = img_feats.unsqueeze(1)
        encoder_outputs = self.t5.encoder(inputs_embeds=encoder_hidden_states)
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:2], device=encoder_hidden_states.device
        )
        return encoder_outputs, encoder_attention_mask

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        """Run T5 on the encoded images, returning the T5 model outputs.

        NOTE(review): the ``attention_mask`` argument is not used; an all-ones
        mask over the image encoding is passed instead. ``input_ids`` is
        forwarded to T5 together with precomputed ``encoder_outputs`` —
        confirm this matches the training setup before changing either.
        """
        encoder_outputs, encoder_attention_mask = self._encode_images(pixel_values)
        return self.t5(
            encoder_outputs=encoder_outputs,
            attention_mask=encoder_attention_mask,
            input_ids=input_ids,
            labels=labels,
        )

    def generate_reports(self, pixel_values, max_length=100, num_beams=4):
        """Generate report token ids for ``pixel_values`` with beam search.

        Uses the same encoding path as ``forward`` (matching the Colab
        SECTION 6 generation code). Decoding the ids to text is left to the
        caller.
        """
        encoder_outputs, encoder_attention_mask = self._encode_images(pixel_values)
        return self.t5.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=encoder_attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
|
|
|
|
# Start-up progress marker: model classes are defined at this point.
print("✓ Model architecture classes defined")
|
|
| |
| |
| |
def _build_model(config: dict) -> "VisionT5Model":
    """Construct a VisionT5Model with the architecture described by ``config``."""
    print(f" Creating CoAtNet encoder: {config['coatnet_model']}")
    # pretrained=False: backbone weights are presumably supplied by the
    # checkpoint loaded afterwards.
    img_encoder = CoAtNetEncoder(
        model_name=config['coatnet_model'],
        pretrained=False,
        train_last_stages=config['train_last_stages'],
    )
    print(f" Creating VisionT5 model with T5: {config['t5_model']}")
    return VisionT5Model(
        img_encoder=img_encoder,
        txt_model_name=config['t5_model'],
        img_emb_dim=config['img_emb_dim'],
    )


def _extract_state_dict(checkpoint):
    """Return the weights from a loaded checkpoint and print its metadata.

    Accepts a wrapper dict holding 'model_state_dict', 'state_dict' or
    'model' (checked in that order), a dict used directly as the state_dict,
    or a non-dict object passed through unchanged.
    """
    if not isinstance(checkpoint, dict):
        print(" Checkpoint is a state_dict")
        return checkpoint

    for key in ('model_state_dict', 'state_dict', 'model'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            print(f" Found '{key}' in checkpoint")
            break
    else:
        state_dict = checkpoint
        print(" Using checkpoint as state_dict directly")

    if 'epoch' in checkpoint:
        print(f" Checkpoint epoch: {checkpoint['epoch']}")
    if 'loss' in checkpoint:
        loss = checkpoint['loss']
        try:
            loss_text = f"{loss:.4f}"
        except (TypeError, ValueError):
            # e.g. a list or multi-element tensor: show it unformatted rather
            # than abort model loading over a log line.
            loss_text = str(loss)
        print(f" Checkpoint loss: {loss_text}")
    return state_dict


def _print_key_list(label, keys):
    """Print how many ``keys`` there are under ``label``; list them if at most 5."""
    if not keys:
        return
    print(f" ⚠️ {label}: {len(keys)}")
    if len(keys) <= 5:
        for key in keys:
            print(f" - {key}")


def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: dict):
    """Build a VisionT5Model from ``config`` and load weights from ``checkpoint_path``.

    Weights are loaded with ``strict=False``: missing or unexpected keys are
    reported on stdout but do not fail the load. The model is moved to the
    module-level ``device`` and put in eval mode.

    Args:
        checkpoint_path: path of the file passed to ``torch.load``.
        model_name: label used only in log messages.
        config: architecture settings (see ``CONFIG``).

    Returns:
        The loaded model, in eval mode.

    Raises:
        Any exception from construction or loading is printed with its
        traceback and re-raised.
    """
    print(f"\nLoading {model_name} model...")
    print(f" Checkpoint: {checkpoint_path}")

    try:
        model = _build_model(config)

        print(" Loading checkpoint weights...")
        # SECURITY NOTE(review): torch.load unpickles arbitrary objects unless
        # weights_only=True (the default only from torch 2.6). This checkpoint
        # may come from a remote hub repo; consider weights_only=True once it
        # is confirmed to hold only tensors and plain metadata.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        state_dict = _extract_state_dict(checkpoint)

        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        _print_key_list("Missing keys", missing_keys)
        _print_key_list("Unexpected keys", unexpected_keys)

        model = model.to(device)
        model.eval()

        print(f"✓ {model_name} model loaded successfully!")
        return model

    except Exception as e:
        print(f"❌ Error loading {model_name} model: {str(e)}")
        traceback.print_exc()
        raise
|
|
|
|
| |
| |
| |
def generate_report(
    image_path: str,
    model: VisionT5Model,
    config: dict
) -> str:
    """Generate a report for the image at ``image_path`` (Colab-equivalent path).

    The image is converted to RGB, preprocessed with ``transform`` and moved
    to ``device``. Generation uses ``config['max_length']`` and
    ``config['num_beams']``, and the first sequence is decoded.

    Returns:
        The stripped report text, or ``""`` if any step fails. Failures are
        best-effort: the error and its traceback are printed, not raised.
    """
    try:
        # Context manager closes the underlying file deterministically.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        pixel_values = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            generated_ids = model.generate_reports(
                pixel_values,
                max_length=config['max_length'],
                num_beams=config['num_beams'],
            )

        report = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return report.strip()

    except Exception as e:
        print(f"Error generating report for {image_path}: {str(e)}")
        traceback.print_exc()
        return ""
|
|
|
|
| |
| |
| |
print("\n" + "="*80)
print("LOADING MODELS FROM HUGGINGFACE")
print("="*80)


def _fetch_checkpoint(label: str, filename: str, fallback_path: str) -> str:
    """Download ``filename`` from the model repo; on failure return ``fallback_path``.

    Each checkpoint falls back independently, so one failed download does not
    discard another that succeeded.
    """
    try:
        path = hf_hub_download(
            repo_id="vinaykumarhs2020/RLHF_radiology_model",
            filename=filename,
        )
    except Exception as e:
        print(f"❌ Error downloading {label} model: {e}")
        print(f"⚠️ Using local path instead: {fallback_path}")
        return fallback_path
    print(f"✓ Downloaded {label} model: {path}")
    return path


# Fallback locations are Colab-style local paths.
SFT_MODEL_PATH = _fetch_checkpoint("SFT", "best_model.pt", "/content/best_model.pt")
PPO_MODEL_PATH = _fetch_checkpoint("PPO", "rlhf_model.pt", "/content/rlhf_model.pt")
|
|
| |
print("\n" + "="*80)
print("LOADING MODELS")
print("="*80)

# Both checkpoints share the architecture in CONFIG. load_model_from_checkpoint
# re-raises on failure, so a bad checkpoint aborts start-up here.
sft_model = load_model_from_checkpoint(SFT_MODEL_PATH, "SFT", CONFIG)
ppo_model = load_model_from_checkpoint(PPO_MODEL_PATH, "PPO", CONFIG)

print("\n✓ Both models loaded successfully!")
|
|
| |
| |
| |
app = FastAPI(title="Medical Report Generation - Matching Colab")

# Fully permissive CORS: any origin may call the API with any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
def preprocess_bytes(file_bytes: bytes) -> torch.Tensor:
    """Decode uploaded image bytes into a one-image batch on ``device``.

    The image is converted to RGB and run through ``transform``; a leading
    batch dimension of size 1 is added.
    """
    buffer = io.BytesIO(file_bytes)
    rgb_image = Image.open(buffer).convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)
    return batch.to(device)
|
|
|
|
@app.get("/health")
def health():
    """Liveness check: service status, compute device and active CONFIG.

    ``models_loaded`` is reported as a constant True. Both models are loaded
    at import time and a load failure re-raises, so the app never starts
    without them.
    """
    return dict(
        status="ok",
        device=str(device),
        models_loaded=True,
        config=CONFIG,
    )
|
|
|
|
# Serializes model inference. Requests wait here in worker threads instead of
# running several generate() calls on the device at once.
_inference_lock = threading.Lock()


def _run_model(model: "VisionT5Model", pixel_values: torch.Tensor) -> str:
    """Generate one report with ``model`` and decode it to stripped text.

    This call blocks. The async endpoints below run it via
    ``asyncio.to_thread`` so long generations do not stall the event loop
    (and with it /health and every other request). ``torch.no_grad`` is
    entered here because grad mode is thread-local and must be set in the
    thread that runs the model.
    """
    with _inference_lock, torch.no_grad():
        generated_ids = model.generate_reports(
            pixel_values,
            max_length=CONFIG['max_length'],
            num_beams=CONFIG['num_beams'],
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()


@app.post("/sft")
async def sft_inference(file: UploadFile = File(...)):
    """SFT model inference (same generation settings as the Colab notebook).

    On failure, returns the error message in ``report`` rather than raising.
    """
    try:
        tensor = preprocess_bytes(await file.read())
        report = await asyncio.to_thread(_run_model, sft_model, tensor)
        print(f"[SFT] Generated: {report}")
        return {"report": report, "model": "SFT", "config_used": CONFIG}
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {str(e)}", "model": "SFT"}


@app.post("/ppo")
async def ppo_inference(file: UploadFile = File(...)):
    """PPO (RLHF) model inference (same generation settings as the Colab notebook).

    On failure, returns the error message in ``report`` rather than raising.
    """
    try:
        tensor = preprocess_bytes(await file.read())
        report = await asyncio.to_thread(_run_model, ppo_model, tensor)
        print(f"[PPO] Generated: {report}")
        return {"report": report, "model": "PPO", "config_used": CONFIG}
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {str(e)}", "model": "PPO"}


@app.post("/compare")
async def compare_models(file: UploadFile = File(...)):
    """Generate reports from both models for the same image, for comparison.

    The upload is preprocessed once and fed to both models. Any failure sets
    both report fields to the error message.
    """
    try:
        tensor = preprocess_bytes(await file.read())

        sft_report = await asyncio.to_thread(_run_model, sft_model, tensor)
        ppo_report = await asyncio.to_thread(_run_model, ppo_model, tensor)

        print(f"[COMPARE] SFT: {sft_report}")
        print(f"[COMPARE] PPO: {ppo_report}")

        return {
            "sft_report": sft_report,
            "ppo_report": ppo_report,
            "config_used": CONFIG,
        }
    except Exception as e:
        traceback.print_exc()
        return {
            "sft_report": f"ERROR: {str(e)}",
            "ppo_report": f"ERROR: {str(e)}",
        }
|
|
|
|
@app.get("/debug_config")
def debug_config():
    """Expose the active configuration, device and model-presence flags."""
    info = {
        "config": CONFIG,
        "device": str(device),
        "tokenizer": CONFIG['t5_model'],
    }
    # Surface the most-inspected settings at top level as well.
    for setting in ("image_size", "max_length", "num_beams"):
        info[setting] = CONFIG[setting]
    info["models_loaded"] = {
        "sft": sft_model is not None,
        "ppo": ppo_model is not None,
    }
    return info
|
|
|
|
| |
| |
| |
from fastapi.staticfiles import StaticFiles

# Serve the built React front-end (if present) at the site root. It is mounted
# after all API routes so those are matched first (Starlette checks routes in
# registration order). isdir, not exists: StaticFiles needs a directory, and a
# stray file named "build" would otherwise crash start-up.
if os.path.isdir("build"):
    app.mount("/", StaticFiles(directory="build", html=True), name="static")
    print("✅ React app mounted at /")
else:
    print("⚠️ Build directory not found, serving API only")
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` with uvicorn on all interfaces.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0", reload=False)