Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor | |
| from qwen_vl_utils import process_vision_info | |
| import os | |
# Checkpoint to load. Anything accepted by `from_pretrained` works here (a
# local directory or a Hub repo id); set the MODEL_PATH environment variable
# to override the default Robust-R1 checkpoint.
MODEL_PATH = os.environ.get("MODEL_PATH", "Jiaqi-hkust/Robust-R1")

# Cached model/processor; both start empty and are populated by load_model().
model, processor = None, None
def load_model():
    """Return the cached ``(model, processor)`` pair, loading it if needed.

    The first successful call reads both objects from ``MODEL_PATH`` and
    stores them in the module-level ``model`` / ``processor`` globals; later
    calls return those cached objects. If either global is still ``None``,
    both are (re)loaded.
    """
    global model, processor

    # Fast path: everything already loaded.
    if model is not None and processor is not None:
        return model, processor

    print(f"Loading model: {MODEL_PATH}")
    # bfloat16 weights; device_map="auto" lets transformers place the weights
    # on whatever devices are available.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    print("Model loaded successfully!")
    return model, processor
def inference(image, question, max_new_tokens=1024, temperature=0.7):
    """Answer a text question about an image with the loaded VL model.

    Args:
        image: The uploaded image, or ``None`` if the user did not provide
            one. The UI's ``gr.Image`` is configured with ``type="pil"``.
        question: The user's question text.
        max_new_tokens: Upper bound on the number of generated tokens. Gradio
            sliders may hand this over as a float, so it is coerced to ``int``.
        temperature: Sampling temperature. Values ``> 0`` enable sampling;
            anything else selects greedy decoding.

    Returns:
        The decoded model response, or a human-readable error/warning string.
        Errors are reported as text (never raised) so the UI can show them.
    """
    try:
        # Validate inputs *before* touching the model, so a missing image or
        # question is reported immediately instead of after a slow model load.
        if image is None:
            return "⚠️ Error: Please upload an image. This is a multimodal model that requires both an image and text input."
        if not question or not question.strip():
            return "⚠️ Error: Please enter your question. This is a multimodal model that requires both an image and text input."

        model, processor = load_model()

        # Single-turn chat message carrying both the image and the question.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Put the inputs on the same device as the model's first parameter.
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        do_sample = temperature > 0
        gen_kwargs = {"max_new_tokens": int(max_new_tokens), "do_sample": do_sample}
        if do_sample:
            # Only pass temperature when sampling: under greedy decoding
            # transformers ignores it and warns about an invalid config.
            gen_kwargs["temperature"] = temperature
        generated_ids = model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]
    except Exception as e:
        # Top-level UI boundary: report the error to the user, but keep the
        # traceback in the logs instead of silently discarding it.
        import traceback

        traceback.print_exc()
        return f"An error occurred: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## Citation
        The following is a BibTeX reference:
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # NOTE: gr.Image does not accept an `info=` keyword (only form
            # inputs such as gr.Textbox do); passing one fails when the
            # component is constructed. NOTE(review): likely the cause of the
            # Space's "Runtime error".
            image_input = gr.Image(
                type="pil",
                label="📸 Upload Image (Required)",
                height=400,
            )
            question_input = gr.Textbox(
                label="💬 Your Question (Required)",
                placeholder="e.g., Describe the content of this image",
                lines=3,
                info="Enter your question about the uploaded image",
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=512,
                    step=64,
                    label="Max Generation Length",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
            submit_btn = gr.Button("Submit", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Model Response",
                lines=15,
                interactive=False,
            )

    # Examples. The prompt text (including its typos) is kept verbatim;
    # presumably it mirrors the structured-output format the model expects.
    gr.Examples(
        examples=[
            ["What is the name of the Garage?\n0. polo\n1. imam\n2. leke\n3. akd\nFirst output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags, and thenoutput what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags, then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END>tags, and then sunmmarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END>tags,provides the user with the answer briefly in<ANSWER> <ANSWER_END>.i.e., <TYPE> degradation type here <TYPE_END>\n<INFLUENCE> influence here<INFLUENCE_END>\n<REASONING> reasoning process here<REASONING_END>\n<CONCLUSION>summary here<CONCLUSION_END>\n<ANSWER>final answer<ANSWER_END>."],
        ],
        inputs=[question_input],
        label="Example Questions",
    )

    # Bind events
    submit_btn.click(
        fn=inference,
        inputs=[image_input, question_input, max_tokens, temperature],
        outputs=output,
    )
    # Reset every control to its initial value (512 / 0.7 match the sliders).
    clear_btn.click(
        fn=lambda: (None, "", 512, 0.7, ""),
        outputs=[image_input, question_input, max_tokens, temperature, output],
    )
    # Show a message when the page loads. NOTE: this only sets the text; the
    # model itself is loaded lazily by load_model() on the first Submit.
    demo.load(
        fn=lambda: "Model is loading, please wait...",
        outputs=output,
    )
if __name__ == "__main__":
    # Serve on all network interfaces at the fixed port 7860; share=False
    # means no public share link is created.
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_kwargs)