Neemah commited on
Commit
08aa07f
·
verified ·
1 Parent(s): dae598b

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +5 -2
model.py CHANGED
@@ -14,16 +14,19 @@ processor = AutoProcessor.from_pretrained(MODEL_ID)
14
 
15
  use_cuda = torch.cuda.is_available()
16
  dtype = torch.bfloat16 if use_cuda else torch.float32
 
17
 
18
  model = AutoModelForImageTextToText.from_pretrained(
19
  MODEL_ID,
20
- dtype=dtype,
21
- device_map=0 if use_cuda else -1
22
  )
23
 
 
24
  model.eval()
25
 
26
  print("MedGemma loaded successfully!")
 
27
 
28
  def generate_report(image):
29
  """
 
14
 
15
  use_cuda = torch.cuda.is_available()
16
  dtype = torch.bfloat16 if use_cuda else torch.float32
17
+ device = "cuda:0" if use_cuda else "cpu"
18
 
19
  model = AutoModelForImageTextToText.from_pretrained(
20
  MODEL_ID,
21
+ torch_dtype=dtype,
22
+ device_map=device
23
  )
24
 
25
+ model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
26
  model.eval()
27
 
28
  print("MedGemma loaded successfully!")
29
+ print(f"MedGemma loaded on: {device}")
30
 
31
  def generate_report(image):
32
  """