Fred808 commited on
Commit
8f1ad47
·
verified ·
1 Parent(s): 386ac2a

Update vision_model.py

Browse files
Files changed (1) hide show
  1. vision_model.py +9 -5
vision_model.py CHANGED
@@ -1,16 +1,20 @@
1
- from transformers import BlipProcessor, Blip2ForConditionalGeneration
2
  from PIL import Image
3
  import torch
4
  import os
 
5
  os.environ["HF_HOME"] = "/app/.cache"
6
 
7
- model_id = "Salesforce/blip2-opt-2.7b"
8
- processor = BlipProcessor.from_pretrained(model_id)
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- model = Blip2ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
 
 
 
11
 
12
  def describe_image(image_path, prompt="Describe this sculpt and any missing regions."):
13
  image = Image.open(image_path).convert("RGB")
14
- inputs = processor(image, prompt, return_tensors="pt").to(device, model.dtype)
15
  output = model.generate(**inputs, max_new_tokens=150)
16
  return processor.decode(output[0], skip_special_tokens=True)
 
1
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
2
  from PIL import Image
3
  import torch
4
  import os
5
+
6
  os.environ["HF_HOME"] = "/app/.cache"
7
 
8
+ model_id = "llava-hf/llava-1.5-7b-hf"
9
+ processor = AutoProcessor.from_pretrained(model_id)
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model = LlavaForConditionalGeneration.from_pretrained(
12
+ model_id,
13
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
14
+ ).to(device)
15
 
16
  def describe_image(image_path, prompt="Describe this sculpt and any missing regions."):
17
  image = Image.open(image_path).convert("RGB")
18
+ inputs = processor(prompt, image, return_tensors="pt").to(device, model.dtype)
19
  output = model.generate(**inputs, max_new_tokens=150)
20
  return processor.decode(output[0], skip_special_tokens=True)