# TESTLLAVA / app.py
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# Configuration
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
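# Larger checkpoints (e.g., "Qwen/Qwen2-VL-7B-Instruct") are drop-in
# replacements here if more VRAM is available.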
print(f"Loading {MODEL_ID}...")
# 1. Load Model
# We load in bfloat16 (16-bit), which is faster than 4-bit quantization for a
# model this small and fits comfortably in 8-16 GB of VRAM.
try:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # min_pixels and max_pixels cap the resolution the processor feeds the
    # model, trading visual detail for speed
    processor = AutoProcessor.from_pretrained(
        MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
    )
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure a GPU is available and the model weights can be downloaded.")
    raise SystemExit(1)
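
# If VRAM is very tight, 4-bit quantization is an alternative to bfloat16.
# A minimal sketch (assumes the optional bitsandbytes package is installed;
# usually slower than bfloat16 at this model size, so not the default here):
#
#   from transformers import BitsAndBytesConfig
#   model = Qwen2VLForConditionalGeneration.from_pretrained(
#       MODEL_ID,
#       quantization_config=BitsAndBytesConfig(
#           load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
#       ),
#       device_map="auto",
#   )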


def chat_response(message, history, image_input):
    """
    Main generation function called by Gradio's ChatInterface.
    Receives the user message, the chat history, and the uploaded PIL image.
    """
    if image_input is None:
        return "Please upload an image first to chat about it!"
    # 2. Prepare the messages for Qwen2-VL
    # Qwen expects a list of messages whose content items carry a "type" key
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_input,  # a PIL image can be passed directly
                },
                {"type": "text", "text": message},
            ],
        }
    ]
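    # Note: `history` is ignored, so each question is answered independently.
    # One option (not implemented here) is to prepend earlier turns from
    # `history` to `messages` as {"role": "user"/"assistant",
    # "content": [{"type": "text", "text": ...}]} entries.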
    # 3. Process inputs
    # qwen_vl_utils converts the image and text into model-ready tensors
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move inputs to the same device as the model
    inputs = inputs.to(model.device)
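    # At this point `inputs` holds input_ids and attention_mask plus the
    # vision tensors (pixel_values and the image grid metadata the model
    # uses to place visual tokens).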
    # 4. Generate Response
    # max_new_tokens is capped at 200 to keep latency low
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
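    # For more deterministic answers (e.g., when reading text in the image),
    # greedy decoding is an option:
    #   generated_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)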
    # 5. Decode output
    # Trim the prompt tokens from the output so only the new response remains
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response
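
# Streaming variant (optional): transformers' TextIteratorStreamer, run with
# model.generate on a background thread, lets partial text be yielded to
# Gradio as it is produced. A sketch, not wired into this demo:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(
#       processor.tokenizer, skip_prompt=True, skip_special_tokens=True
#   )
#   Thread(target=model.generate,
#          kwargs=dict(**inputs, max_new_tokens=200, streamer=streamer)).start()
#   partial = ""
#   for chunk in streamer:
#       partial += chunk
#       yield partial  # ChatInterface accepts generator functions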

# --- Gradio UI Setup ---
with gr.Blocks(title="Qwen2-VL Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Qwen2-VL-2B: Fast Image Chat")
    gr.Markdown("Upload an image and ask questions. This 2B model is significantly faster than LLaVA-7B.")
    with gr.Row():
        with gr.Column(scale=1):
            image_box = gr.Image(type="pil", label="Upload Image")
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface(
                fn=chat_response,
                additional_inputs=[image_box],
                title="Chat",
                description="Ask about the uploaded image.",
                examples=[
                    ["What is in this image?", None],
                    ["Describe the lighting.", None],
                    ["Read the text in the image.", None],
                ],
            )

if __name__ == "__main__":
    demo.queue().launch()
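    # launch() accepts extra options if the default local URL is not enough,
    # e.g. server_name="0.0.0.0" for LAN access or share=True for a temporary
    # public link.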