import os

import gradio as gr
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Model path: read from the MODEL_PATH environment variable, falling back to
# the default Hugging Face Hub checkpoint
MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1")

# Global variables that cache the model and processor between requests
model = None
processor = None


def load_model():
    """Load the model and processor once and cache them globally."""
    global model, processor
    if model is None or processor is None:
        print(f"Loading model: {MODEL_PATH}")
        # Load model
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # Load processor
        processor = AutoProcessor.from_pretrained(MODEL_PATH)
        print("Model loaded successfully!")
    return model, processor


def inference(image, question, max_new_tokens=1024, temperature=0.7):
    """Run one image + text query through the model and return the response."""
    try:
        # Validate the multimodal inputs before touching the model
        if image is None:
            return "⚠️ Error: Please upload an image. This is a multimodal model that requires both an image and a text input."
        if not question or question.strip() == "":
            return "⚠️ Error: Please enter your question. This is a multimodal model that requires both an image and a text input."

        # Ensure the model is loaded (lazy-loads on the first request)
        model, processor = load_model()

        # Build the multimodal message (image + text)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},   # image input
                    {"type": "text", "text": question},  # text input
                ],
            }
        ]

        # Prepare inputs
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move the inputs to the device the model lives on
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        # Generate the response
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )

        # Strip the prompt tokens and decode only the newly generated ones
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return output_text[0]

    except Exception as e:
        return f"An error occurred: {e}"


# Build the Gradio interface
with gr.Blocks(title="Robust-R1: Visual Understanding Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Robust-R1: Visual Understanding Demo
        Upload an image and enter a question about it. The model is multimodal
        and requires both an image and a text input.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="📸 Upload Image (Required)",
                height=400,
            )

            question_input = gr.Textbox(
                label="💬 Your Question (Required)",
                placeholder="e.g., Describe the content of this image",
                lines=3,
                info="Enter your question about the uploaded image",
            )

            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=64, maximum=2048, value=512, step=64,
                    label="Max Generation Length",
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                    label="Temperature",
                )

            submit_btn = gr.Button("Submit", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Model Response",
                lines=15,
                interactive=False,
            )

    # Example questions: one value per row, matching the single bound input
    gr.Examples(
        examples=[
            ["Describe this image"],
            ["What does this image show?"],
        ],
        inputs=[question_input],
        label="Example Questions",
    )

    # Bind events
    submit_btn.click(
        fn=inference,
        inputs=[image_input, question_input, max_tokens, temperature],
        outputs=output,
    )

    clear_btn.click(
        fn=lambda: (None, "", 512, 0.7, ""),
        outputs=[image_input, question_input, max_tokens, temperature, output],
    )

    # Show a status message when the page loads (the model itself is
    # lazy-loaded on the first request)
    demo.load(
        fn=lambda: "Model loads on the first request; the first response may take a while.",
        outputs=output,
    )


if __name__ == "__main__":
    # When running in a Hugging Face Space, Gradio handles the port automatically
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
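

# Usage sketch (illustrative; not executed by the demo). Assuming this file is
# saved as app.py, the app can be launched locally against an alternative
# checkpoint via the MODEL_PATH environment variable read at the top:
#
#   MODEL_PATH=Jiaqi-hkust/Robust-R1 python app.py   # then open http://localhost:7860
#
# inference() can also be called directly, without the UI, since gr.Image uses
# type="pil" (the image path here is hypothetical):
#
#   from PIL import Image
#   print(inference(Image.open("example.jpg"), "Describe this image"))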