safiaa02 commited on
Commit
38d7e0c
·
verified ·
1 Parent(s): a34d86a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -6
model.py CHANGED
@@ -7,11 +7,11 @@ class ChestGPTDemo:
7
  def __init__(self, device=None):
8
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
- # Load clinical GPT-2 model
11
- self.tokenizer = AutoTokenizer.from_pretrained("mrm8488/GPT-2-finetuned-clinical-notes")
12
- self.lm = AutoModelForCausalLM.from_pretrained("mrm8488/GPT-2-finetuned-clinical-notes").to(self.device).eval()
13
 
14
- # Few-shot prompt to guide generation
15
  self.prompt = (
16
  "[radiology] Example 1:\n"
17
  "Global Disease: Cardiomegaly\n"
@@ -23,11 +23,11 @@ class ChestGPTDemo:
23
  )
24
 
25
  def process_image(self, img: Image.Image):
26
- # Placeholder for image features will add ViT later
27
  return None
28
 
29
  def predict(self, img: Image.Image):
30
- _ = self.process_image(img) # skip vision for now
31
  inputs = self.tokenizer(self.prompt, return_tensors="pt", padding=True).to(self.device)
32
  outputs = self.lm.generate(
33
  **inputs,
 
7
  def __init__(self, device=None):
8
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
+ # Tiny GPT-2 model (lightweight and public)
11
+ self.tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
12
+ self.lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").to(self.device).eval()
13
 
14
+ # Few-shot prompt to simulate clinical logic
15
  self.prompt = (
16
  "[radiology] Example 1:\n"
17
  "Global Disease: Cardiomegaly\n"
 
23
  )
24
 
25
  def process_image(self, img: Image.Image):
26
+ # Placeholder for image encoding will integrate ViT later
27
  return None
28
 
29
  def predict(self, img: Image.Image):
30
+ _ = self.process_image(img)
31
  inputs = self.tokenizer(self.prompt, return_tensors="pt", padding=True).to(self.device)
32
  outputs = self.lm.generate(
33
  **inputs,