Fred808 commited on
Commit
53db934
·
verified ·
1 Parent(s): 8f1ad47

Update vision_model.py

Browse files
Files changed (1) hide show
  1. vision_model.py +8 -11
vision_model.py CHANGED
@@ -1,20 +1,17 @@
1
- from transformers import AutoProcessor, LlavaForConditionalGeneration
2
  from PIL import Image
3
  import torch
4
  import os
5
 
6
  os.environ["HF_HOME"] = "/app/.cache"
7
 
8
- model_id = "llava-hf/llava-1.5-7b-hf"
9
- processor = AutoProcessor.from_pretrained(model_id)
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- model = LlavaForConditionalGeneration.from_pretrained(
12
- model_id,
13
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
14
- ).to(device)
15
 
16
- def describe_image(image_path, prompt="Describe this sculpt and any missing regions."):
17
  image = Image.open(image_path).convert("RGB")
18
- inputs = processor(prompt, image, return_tensors="pt").to(device, model.dtype)
19
- output = model.generate(**inputs, max_new_tokens=150)
20
- return processor.decode(output[0], skip_special_tokens=True)
 
1
+ from transformers import BlipProcessor, BlipForConditionalGeneration
2
  from PIL import Image
3
  import torch
4
  import os
5
 
6
  os.environ["HF_HOME"] = "/app/.cache"
7
 
8
+ model_id = "Salesforce/blip-image-captioning-large"
9
+ processor = BlipProcessor.from_pretrained(model_id)
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
 
 
 
12
 
13
+ def describe_image(image_path, prompt="Describe this image."):
14
  image = Image.open(image_path).convert("RGB")
15
+ inputs = processor(image, prompt, return_tensors="pt").to(device)
16
+ output = model.generate(**inputs, max_new_tokens=100)
17
+ return processor.decode(output[0], skip_special_tokens=True)