Fred808 commited on
Commit
4401a89
·
verified ·
1 Parent(s): a57aa35

Update vision_model.py

Browse files
Files changed (1) hide show
  1. vision_model.py +7 -5
vision_model.py CHANGED
@@ -1,17 +1,19 @@
1
  from transformers import AutoProcessor, AutoModelForVision2Seq
2
  from PIL import Image
3
  import torch
4
-
5
  import os
6
  os.environ["HF_HOME"] = "/app/.cache"
7
 
8
-
9
  model_id = "llava-hf/llava-1.5-7b-hf"
10
  processor = AutoProcessor.from_pretrained(model_id)
11
- model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16).cuda()
 
 
 
 
12
 
13
  def describe_image(image_path, prompt="Describe this sculpt and any missing regions."):
14
  image = Image.open(image_path).convert("RGB")
15
- inputs = processor(prompt, image, return_tensors="pt").to("cuda", torch.float16)
16
  output = model.generate(**inputs, max_new_tokens=150)
17
- return processor.decode(output[0], skip_special_tokens=True)
 
1
  from transformers import AutoProcessor, AutoModelForVision2Seq
2
  from PIL import Image
3
  import torch
 
4
  import os
5
  os.environ["HF_HOME"] = "/app/.cache"
6
 
 
7
  model_id = "llava-hf/llava-1.5-7b-hf"
8
  processor = AutoProcessor.from_pretrained(model_id)
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ model = AutoModelForVision2Seq.from_pretrained(
11
+ model_id,
12
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
13
+ ).to(device)
14
 
15
  def describe_image(image_path, prompt="Describe this sculpt and any missing regions."):
16
  image = Image.open(image_path).convert("RGB")
17
+ inputs = processor(prompt, image, return_tensors="pt").to(device, model.dtype)
18
  output = model.generate(**inputs, max_new_tokens=150)
19
+ return processor.decode(output[0], skip_special_tokens=True)