|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
|
|
|
class ChestGPTDemo:
    """Toy chest X-ray "report" generator built on a small causal language model.

    NOTE: this is a demo scaffold. ``process_image`` is a placeholder that
    returns ``None`` and ``predict`` discards its result, so the generated text
    depends only on the fixed few-shot text prompt and the model's decoding —
    never on the pixels of the X-ray that is passed in.
    """

    def __init__(self, device=None, model_name: str = "sshleifer/tiny-gpt2"):
        """Load the tokenizer and causal LM and build the few-shot prompt.

        Args:
            device: Torch device (string or ``torch.device``) to run on. When
                ``None`` (the default), ``"cuda"`` is used if available,
                otherwise ``"cpu"``.
            model_name: Hugging Face Hub id or local path for both the
                tokenizer and the causal LM. Defaults to ``"sshleifer/tiny-gpt2"``.
        """
        # `is not None` (not `or`) so falsy-but-valid devices such as index 0
        # are respected instead of being silently replaced.
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some tokenizers (GPT-2's among them) define no pad token; reuse EOS
        # so any padding request and generate()'s pad_token_id are well defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.lm = AutoModelForCausalLM.from_pretrained(model_name).to(self.device).eval()

        # Fixed few-shot prompt: two worked examples followed by the request.
        # It contains no information about the actual image.
        self.prompt = (
            "[radiology] Example 1:\n"
            "Global Disease: Cardiomegaly\n"
            "Local Finding: Patchy opacity in right lower lobe (BBox: 50,60,120,150)\n\n"
            "[radiology] Example 2:\n"
            "Global Disease: Normal\n"
            "Local Finding: No abnormalities detected\n\n"
            "[radiology] Please describe this chest X-ray. Mention global diseases and local findings if visible.\n"
        )

    def process_image(self, img: Image.Image):
        """Placeholder image-preprocessing hook.

        No vision encoder is wired into this demo, so the image is not
        converted into features.

        Args:
            img: Input image (ignored).

        Returns:
            Always ``None``.
        """
        return None

    def predict(
        self,
        img: Image.Image,
        *,
        max_new_tokens: int = 100,
        return_full_text: bool = True,
    ) -> str:
        """Generate a free-text "report" from the few-shot prompt.

        Args:
            img: Input chest X-ray. It is passed to ``process_image`` but the
                result is discarded, so it does not influence the output.
            max_new_tokens: Maximum number of tokens to generate (default 100).
            return_full_text: If True (default), decode the whole output
                sequence, i.e. the prompt followed by the continuation. If
                False, decode only the newly generated tokens.

        Returns:
            The decoded text with special tokens removed.
        """
        _ = self.process_image(img)  # placeholder; result intentionally unused

        # A single prompt needs no padding. Requesting padding from a tokenizer
        # without a pad token raises ValueError, so it is not requested here.
        inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
        prompt_len = inputs["input_ids"].shape[-1]

        outputs = self.lm.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        # For a causal LM, generate() returns prompt tokens followed by new ones.
        sequence = outputs[0] if return_full_text else outputs[0][prompt_len:]
        return self.tokenizer.decode(sequence, skip_special_tokens=True)
|
|
|