# LLAVA-API / app.py
# Author: Ishgan
# Last change: "Update app.py" (commit 0c077d2, verified)
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
# Checkpoint to load from the Hugging Face Hub.
model_id = "llava-hf/llava-1.5-7b-hf"

# Pick the device, and a dtype that device can run well: float16 halves GPU
# memory, but half-precision on CPU is slow (and unsupported for some ops),
# so fall back to float32 there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load the model weights in the chosen dtype and move them to the device.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
).to(device)

# The processor bundles the tokenizer (text) and the image processor (pixels).
processor = AutoProcessor.from_pretrained(model_id)

# Single-turn conversation: a text question plus one image placeholder; the
# actual image is passed to the processor separately below.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]

# Render the conversation into the model's prompt format; add_generation_prompt
# appends the assistant-turn marker so generation continues as the assistant.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Download the image. raise_for_status() fails fast on HTTP errors instead of
# handing an error page to PIL; the timeout bounds the wait. convert() forces
# PIL to fully read the lazily-decoded stream before the response is closed.
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
with requests.get(image_url, stream=True, timeout=30) as response:
    response.raise_for_status()
    raw_image = Image.open(response.raw).convert("RGB")

# Tokenize the prompt and preprocess the image into PyTorch tensors.
inputs = processor(images=raw_image, text=prompt, return_tensors="pt")
# Move every tensor to the device, but cast only floating-point tensors
# (e.g. pixel_values) to the model dtype: input_ids / attention_mask hold
# integer token indices and must stay integral for the embedding lookup.
inputs = {
    name: (
        tensor.to(device, dtype)
        if torch.is_floating_point(tensor)
        else tensor.to(device)
    )
    for name, tensor in inputs.items()
}

# Greedy decoding (do_sample=False), capped at 200 new tokens.
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)

# generate() returns the prompt tokens followed by the newly generated ones;
# skip the prompt so only the model's answer is decoded and printed.
prompt_length = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][prompt_length:], skip_special_tokens=True))