File size: 1,299 Bytes
24665ce
8c1f0af
457ec39
 
 
0c077d2
 
ba92118
0c077d2
 
ba92118
0c077d2
457ec39
 
 
0c077d2
 
ba92118
0c077d2
457ec39
 
0c077d2
457ec39
 
0c077d2
 
 
 
457ec39
 
 
0c077d2
 
457ec39
8c1f0af
0c077d2
 
 
8c1f0af
0c077d2
 
 
 
 
457ec39
0c077d2
 
457ec39
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Model ID
model_id = "llava-hf/llava-1.5-7b-hf"

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Floating-point precision, defined once so the model weights and the
# floating-point inputs built below always use the same dtype.
dtype = torch.float16

# Load model onto the correct device
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
).to(device)

# Load processor
processor = AutoProcessor.from_pretrained(model_id)

# Define conversation: a single user turn with a text question and one
# image placeholder.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]

# Apply chat template
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Load image
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Use a timeout and check the HTTP status so a hung connection or an error
# response fails here instead of being passed to PIL as image data.
response = requests.get(image_url, stream=True, timeout=30)
response.raise_for_status()
raw_image = Image.open(response.raw)

# Preprocess inputs
inputs = processor(images=raw_image, text=prompt, return_tensors='pt')
# Move every tensor to the target device, but cast ONLY floating-point
# tensors to `dtype`. Integer tensors must keep their integer dtype:
# casting them to float16 would corrupt the values and break index lookups.
inputs = {
    name: (
        tensor.to(device, dtype)
        if torch.is_floating_point(tensor)
        else tensor.to(device)
    )
    for name, tensor in inputs.items()
}

# Generate output (greedy decoding, at most 200 new tokens)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)

# Decode and print result; the first two token ids of the returned sequence
# are skipped before decoding.
print(processor.decode(output[0][2:], skip_special_tokens=True))