import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Configuration
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

print(f"Loading {MODEL_ID}...")

# 1. Load Model
# We use bfloat16 (half precision), which is faster than 4-bit quantization for
# small models and fits easily in 16GB or even 8GB of VRAM.
try:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # The min_pixels and max_pixels arguments help control resolution for speed
    processor = AutoProcessor.from_pretrained(
        MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
    )
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure you have a GPU available.")
    exit()


def chat_response(message, history, image_input):
    """
    Main generation function called by Gradio.
    """
    if image_input is None:
        return "Please upload an image first to chat about it!"

    # 2. Prepare the messages for Qwen2-VL
    # Qwen expects a specific format: a list of messages whose content entries
    # carry explicit 'type' keys.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_input,  # Pass the PIL image directly
                },
                {"type": "text", "text": message},
            ],
        }
    ]

    # 3. Process inputs
    # qwen_vl_utils extracts the vision inputs so the processor can turn
    # the image and text into tensors.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move inputs to the same device as the model
    inputs = inputs.to(model.device)

    # 4. Generate Response
    # We limit max_new_tokens to 200 for speed
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # 5. Decode output
    # We trim the input tokens from the output to get only the new response
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return response


# --- Gradio UI Setup ---
with gr.Blocks(title="Qwen2-VL Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Qwen2-VL-2B: Fast Image Chat")
    gr.Markdown(
        "Upload an image and ask questions. This 2B model is significantly "
        "faster than LLaVA-7B."
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_box = gr.Image(type="pil", label="Upload Image")
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface(
                fn=chat_response,
                additional_inputs=[image_box],
                title="Chat",
                description="Ask about the uploaded image.",
                examples=[
                    ["What is in this image?", None],
                    ["Describe the lighting.", None],
                    ["Read the text in the image.", None],
                ],
            )

if __name__ == "__main__":
    demo.queue().launch()
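

# --- Optional: streaming variant (a minimal sketch, not part of the original app) ---
# transformers' TextIteratorStreamer yields decoded text as tokens are generated,
# so Gradio can show the answer incrementally instead of waiting for the full
# response. This helper reuses the same `model`, `processor`, and message format
# as chat_response above. To actually use it, define it before the UI block and
# pass fn=chat_response_streaming to gr.ChatInterface (generator functions are
# supported for streaming output).
from threading import Thread

from transformers import TextIteratorStreamer


def chat_response_streaming(message, history, image_input):
    if image_input is None:
        yield "Please upload an image first to chat about it!"
        return

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_input},
                {"type": "text", "text": message},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt=True drops the echoed input tokens; extra kwargs are
    # forwarded to the tokenizer's decode call.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # generate() blocks, so run it in a background thread and consume the
    # streamer on this one, yielding the accumulated text to Gradio.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial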