File size: 1,558 Bytes
045c6f2 1c3f102 38d7e0c a34d86a 38d7e0c 045c6f2 a34d86a 045c6f2 38d7e0c a34d86a 045c6f2 38d7e0c a34d86a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torchvision import transforms
from PIL import Image
class ChestGPTDemo:
    """Text-only demo that generates a chest X-ray style report.

    A tiny public GPT-2 checkpoint is prompted with two few-shot radiology
    examples followed by an instruction.  The image handed to
    :meth:`predict` does not influence the output yet, because
    :meth:`process_image` is still a placeholder that returns ``None``.
    """

    def __init__(self, device=None):
        """Load the tokenizer and language model and build the prompt.

        Args:
            device: Device the model and inputs are moved to.  When falsy,
                ``"cuda"`` is chosen if ``torch.cuda.is_available()`` and
                ``"cpu"`` otherwise.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Tiny GPT-2 model (lightweight and public).
        # GPT-2 style tokenizers usually define no padding token, and
        # tokenizing with ``padding=True`` (see ``predict``) raises a
        # ValueError in that case, so fall back to the EOS token.
        self.tokenizer = self._ensure_pad_token(
            AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
        )
        self.lm = (
            AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
            .to(self.device)
            .eval()
        )

        # Few-shot prompt to simulate clinical logic.
        self.prompt = (
            "[radiology] Example 1:\n"
            "Global Disease: Cardiomegaly\n"
            "Local Finding: Patchy opacity in right lower lobe (BBox: 50,60,120,150)\n\n"
            "[radiology] Example 2:\n"
            "Global Disease: Normal\n"
            "Local Finding: No abnormalities detected\n\n"
            "[radiology] Please describe this chest X-ray. Mention global diseases and local findings if visible.\n"
        )

    @staticmethod
    def _ensure_pad_token(tokenizer):
        """Give *tokenizer* a padding token if it has none, and return it.

        An existing padding token is left untouched; otherwise the
        tokenizer's EOS token is reused, which matches the
        ``pad_token_id`` passed to ``generate`` in :meth:`predict`.
        """
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    # Annotations are quoted so they are not evaluated when the class is
    # defined.
    def process_image(self, img: "Image.Image") -> None:
        """Placeholder for image encoding; currently always returns ``None``.

        Args:
            img: Input image.  Ignored for now.
        """
        # Placeholder for image encoding — will integrate ViT later.
        return None

    def predict(self, img: "Image.Image") -> str:
        """Generate report text from the fixed few-shot prompt.

        Args:
            img: Input image.  Passed to :meth:`process_image`, whose result
                is currently discarded, so the output does not depend on it.

        Returns:
            The first generated sequence decoded to text with special tokens
            skipped.
            NOTE(review): decoder-only ``generate`` normally returns the
            prompt tokens followed by the continuation, so the text
            presumably starts with the prompt — confirm before parsing it.
        """
        _ = self.process_image(img)
        inputs = self.tokenizer(
            self.prompt, return_tensors="pt", padding=True
        ).to(self.device)
        outputs = self.lm.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|