import os

import gradio as gr
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Model path configuration - can be loaded from an environment variable or the default path
MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1")

# Global variables to cache the model and processor
model = None
processor = None


def load_model():
    """Load the model and processor, caching them in module-level globals."""
    global model, processor
    if model is None or processor is None:
        print(f"Loading model: {MODEL_PATH}")
        # Load the model
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # Load the processor
        processor = AutoProcessor.from_pretrained(MODEL_PATH)
        print("Model loaded successfully!")
    return model, processor


def inference(image, question, max_new_tokens=1024, temperature=0.7):
    """Run a single image + text inference pass."""
    try:
        # Ensure the model is loaded
        model, processor = load_model()

        # Validate multimodal inputs
        if image is None:
            return "⚠️ Error: Please upload an image. This is a multimodal model that requires both an image and text input."
        if not question or question.strip() == "":
            return "⚠️ Error: Please enter your question. This is a multimodal model that requires both an image and text input."

        # Build multimodal messages (image + text)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,  # Image input
                    },
                    {"type": "text", "text": question},  # Text input
                ],
            }
        ]

        # Prepare inputs
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the device where the model lives
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        # Generate the response
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )

        # Strip the prompt tokens and decode only the newly generated ones
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return output_text[0]

    except Exception as e:
        return f"An error occurred: {str(e)}"


# Create the Gradio interface
with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
## Citation

The following is a BibTeX reference:
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Upload an image that you want to ask questions about
            # (gr.Image does not accept an `info` argument, so the hint lives here)
            image_input = gr.Image(
                type="pil",
                label="📸 Upload Image (Required)",
                height=400,
            )
            question_input = gr.Textbox(
                label="💬 Your Question (Required)",
                placeholder="e.g., Describe the content of this image",
                lines=3,
                info="Enter your question about the uploaded image",
            )

            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=512,
                    step=64,
                    label="Max Generation Length",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )

            submit_btn = gr.Button("Submit", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Model Response",
                lines=15,
                interactive=False,
            )

    # Examples
    gr.Examples(
        examples=[
            [
                "What is the name of the Garage?\n0. polo\n1. imam\n2. leke\n3. akd\n"
                "First output the types of degradations in the image briefly in tags, "
                "then output what effects these degradations have on the image in tags, "
                "then, based on the strength of the degradation, output an APPROPRIATE "
                "length for the reasoning process in tags, then summarize the content of "
                "the reasoning and give the answer in tags, providing the user with the "
                "answer briefly, i.e., degradation type here\ninfluence here\n"
                "reasoning process here\nsummary here\nfinal answer."
            ],
        ],
        inputs=[question_input],
        label="Example Questions",
    )

    # Bind events
    submit_btn.click(
        fn=inference,
        inputs=[image_input, question_input, max_tokens, temperature],
        outputs=output,
    )
    clear_btn.click(
        fn=lambda: (None, "", 512, 0.7, ""),
        outputs=[image_input, question_input, max_tokens, temperature, output],
    )

    # Show a status message when the page loads
    demo.load(
        fn=lambda: "Model is loading, please wait...",
        outputs=output,
    )


if __name__ == "__main__":
    # When running in a Space, Gradio handles the port automatically
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)