File size: 1,558 Bytes
045c6f2
 
 
 
 
 
1c3f102
 
 
38d7e0c
 
 
a34d86a
38d7e0c
045c6f2
a34d86a
 
 
 
 
 
 
045c6f2
 
 
38d7e0c
a34d86a
045c6f2
 
38d7e0c
a34d86a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torchvision import transforms
from PIL import Image

class ChestGPTDemo:
    """Minimal chest X-ray report demo backed by a tiny public GPT-2 model.

    The image branch is currently a stub (see ``process_image``); generated
    text comes entirely from a few-shot radiology-style prompt, so the output
    does not yet depend on the input image.
    """

    def __init__(self, device=None):
        """Load the tokenizer and language model.

        Args:
            device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``).
                When ``None``, CUDA is used if available, else CPU.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Tiny GPT-2 model (lightweight and public)
        self.tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
        # GPT-2 tokenizers define no pad token, and tokenizing with
        # padding=True (as predict() does) raises ValueError without one.
        # Reuse EOS as pad — consistent with the pad_token_id passed to
        # generate() below.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").to(self.device).eval()

        # Few-shot prompt to simulate clinical logic
        self.prompt = (
            "[radiology] Example 1:\n"
            "Global Disease: Cardiomegaly\n"
            "Local Finding: Patchy opacity in right lower lobe (BBox: 50,60,120,150)\n\n"
            "[radiology] Example 2:\n"
            "Global Disease: Normal\n"
            "Local Finding: No abnormalities detected\n\n"
            "[radiology] Please describe this chest X-ray. Mention global diseases and local findings if visible.\n"
        )

    def process_image(self, img: Image.Image):
        """Encode *img* for the model. Placeholder — will integrate ViT later.

        Returns:
            None (image features are not used yet).
        """
        return None

    def predict(self, img: Image.Image) -> str:
        """Generate a radiology-style description.

        Args:
            img: Input chest X-ray. Currently ignored by generation
                (``process_image`` returns None).

        Returns:
            The decoded prompt plus up to 100 generated tokens.
        """
        _ = self.process_image(img)
        inputs = self.tokenizer(self.prompt, return_tensors="pt", padding=True).to(self.device)
        # Inference only: skip autograd bookkeeping during generation.
        with torch.no_grad():
            outputs = self.lm.generate(
                **inputs,
                max_new_tokens=100,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)