# ChestGPT / model.py
# (Hugging Face repo metadata: uploaded by safiaa02, "Update model.py", commit 38d7e0c verified)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torchvision import transforms
from PIL import Image
class ChestGPTDemo:
    """Demo wrapper that 'reports' on a chest X-ray with a tiny language model.

    The image is currently ignored (``process_image`` is a placeholder); the
    textual output comes purely from a fixed few-shot radiology prompt fed to
    ``sshleifer/tiny-gpt2``, a minimal public GPT-2 checkpoint.
    """

    def __init__(self, device=None):
        """Load tokenizer and model.

        Args:
            device: Torch device string (e.g. ``"cuda"`` / ``"cpu"``). Defaults
                to CUDA when available, otherwise CPU.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Tiny GPT-2 model (lightweight and public)
        self.tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
        # BUGFIX: GPT-2 tokenizers ship without a pad token, so the
        # `padding=True` call in predict() would raise
        # "Asking to pad but the tokenizer does not have a padding token".
        # Reusing EOS as the pad token is the standard GPT-2 workaround and
        # matches the pad_token_id already passed to generate() below.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.lm = (
            AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
            .to(self.device)
            .eval()
        )
        # Few-shot prompt to simulate clinical logic
        self.prompt = (
            "[radiology] Example 1:\n"
            "Global Disease: Cardiomegaly\n"
            "Local Finding: Patchy opacity in right lower lobe (BBox: 50,60,120,150)\n\n"
            "[radiology] Example 2:\n"
            "Global Disease: Normal\n"
            "Local Finding: No abnormalities detected\n\n"
            "[radiology] Please describe this chest X-ray. Mention global diseases and local findings if visible.\n"
        )

    def process_image(self, img: Image.Image):
        """Encode the image for the model.

        Placeholder for image encoding — will integrate ViT later. Always
        returns ``None`` for now; the image does not influence the output.
        """
        return None

    def predict(self, img: Image.Image) -> str:
        """Generate a pseudo-report for *img*.

        Args:
            img: Input chest X-ray (currently unused — see ``process_image``).

        Returns:
            The decoded text: the few-shot prompt plus up to 100 generated tokens.
        """
        _ = self.process_image(img)
        inputs = self.tokenizer(self.prompt, return_tensors="pt", padding=True).to(self.device)
        # Inference only — skip autograd bookkeeping to save memory/time.
        with torch.no_grad():
            outputs = self.lm.generate(
                **inputs,
                max_new_tokens=100,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)