import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torchvision import transforms
from PIL import Image


class ChestGPTDemo:
    """Toy chest X-ray report generator.

    Uses a tiny public GPT-2 checkpoint with a fixed few-shot radiology
    prompt to *simulate* report generation. The image is currently ignored
    (`process_image` is a placeholder until a vision encoder is wired in),
    so the output depends only on the prompt.
    """

    def __init__(self, device=None):
        """Load the tokenizer and language model.

        Args:
            device: Torch device string (e.g. "cuda", "cpu"). Defaults to
                "cuda" when available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Tiny GPT-2 model (lightweight and public)
        self.tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
        # BUGFIX: GPT-2 tokenizers have no pad token; without this,
        # tokenizer(..., padding=True) in predict() raises
        # "Asking to pad but the tokenizer does not have a padding token".
        # Reusing EOS as PAD matches the pad_token_id passed to generate().
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.lm = (
            AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
            .to(self.device)
            .eval()
        )

        # Few-shot prompt to simulate clinical logic
        self.prompt = (
            "[radiology] Example 1:\n"
            "Global Disease: Cardiomegaly\n"
            "Local Finding: Patchy opacity in right lower lobe (BBox: 50,60,120,150)\n\n"
            "[radiology] Example 2:\n"
            "Global Disease: Normal\n"
            "Local Finding: No abnormalities detected\n\n"
            "[radiology] Please describe this chest X-ray. Mention global diseases and local findings if visible.\n"
        )

    def process_image(self, img: Image.Image):
        """Encode the X-ray image for the language model.

        Placeholder for image encoding — will integrate ViT later.
        Currently always returns None; the image does not influence
        the generated text.
        """
        return None

    def predict(self, img: Image.Image) -> str:
        """Generate a pseudo radiology report for *img*.

        Args:
            img: Input chest X-ray (currently unused — see process_image).

        Returns:
            The decoded model output, which includes the few-shot prompt
            followed by up to 100 newly generated tokens.
        """
        _ = self.process_image(img)
        inputs = self.tokenizer(
            self.prompt, return_tensors="pt", padding=True
        ).to(self.device)
        # inference_mode: the model is frozen (.eval()); skip autograd
        # bookkeeping during generation.
        with torch.inference_mode():
            outputs = self.lm.generate(
                **inputs,
                max_new_tokens=100,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)