Update model.py
Browse files
model.py
CHANGED
|
@@ -7,11 +7,11 @@ class ChestGPTDemo:
|
|
| 7 |
def __init__(self, device=None):
|
| 8 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 9 |
|
| 10 |
-
#
|
| 11 |
-
self.tokenizer = AutoTokenizer.from_pretrained("
|
| 12 |
-
self.lm = AutoModelForCausalLM.from_pretrained("
|
| 13 |
|
| 14 |
-
# Few-shot prompt to
|
| 15 |
self.prompt = (
|
| 16 |
"[radiology] Example 1:\n"
|
| 17 |
"Global Disease: Cardiomegaly\n"
|
|
@@ -23,11 +23,11 @@ class ChestGPTDemo:
|
|
| 23 |
)
|
| 24 |
|
| 25 |
def process_image(self, img: Image.Image):
|
| 26 |
-
# Placeholder for image
|
| 27 |
return None
|
| 28 |
|
| 29 |
def predict(self, img: Image.Image):
|
| 30 |
-
_ = self.process_image(img)
|
| 31 |
inputs = self.tokenizer(self.prompt, return_tensors="pt", padding=True).to(self.device)
|
| 32 |
outputs = self.lm.generate(
|
| 33 |
**inputs,
|
|
|
|
| 7 |
def __init__(self, device=None):
|
| 8 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 9 |
|
| 10 |
+
# Tiny GPT-2 model (lightweight and public)
|
| 11 |
+
self.tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
|
| 12 |
+
self.lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").to(self.device).eval()
|
| 13 |
|
| 14 |
+
# Few-shot prompt to simulate clinical logic
|
| 15 |
self.prompt = (
|
| 16 |
"[radiology] Example 1:\n"
|
| 17 |
"Global Disease: Cardiomegaly\n"
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
def process_image(self, img: Image.Image):
|
| 26 |
+
# Placeholder for image encoding — will integrate ViT later
|
| 27 |
return None
|
| 28 |
|
| 29 |
def predict(self, img: Image.Image):
|
| 30 |
+
_ = self.process_image(img)
|
| 31 |
inputs = self.tokenizer(self.prompt, return_tensors="pt", padding=True).to(self.device)
|
| 32 |
outputs = self.lm.generate(
|
| 33 |
**inputs,
|