# Robust-R1 demo — Hugging Face Space (runs on ZeroGPU).
import html
import os

import gradio as gr
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# --- Hugging Face Spaces / ZeroGPU detection --------------------------------
# A SPACE_ID env var means we are running inside a Space; the optional
# `spaces` package supplies the ZeroGPU `GPU` decorator there.
is_spaces = os.getenv("SPACE_ID") is not None
spaces_available = False
GPU = None
if is_spaces:
    try:
        from spaces import GPU
        spaces_available = True
    except ImportError:
        # Not a ZeroGPU Space (or `spaces` not installed): run undecorated.
        pass


def gpu_decorator(func):
    """Apply the ZeroGPU decorator to *func* when available, else return it unchanged."""
    use_zero_gpu = spaces_available and GPU is not None
    return GPU(func) if use_zero_gpu else func
# System prompt steering the model's structured degradation-aware output
# (<TYPE>/<INFLUENCE>/<REASONING>/<CONCLUSION>/<ANSWER> tag sections).
# NOTE(review): kept verbatim, wording quirks included — it presumably matches
# the exact prompt format used during RL training; confirm before editing.
sys_prompt = """First output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
provides the user with the answer briefly in <ANSWER> <ANSWER_END>."""

# Directory containing this file; used for assets and Gradio allowed_paths.
project_dir = os.path.dirname(os.path.abspath(__file__))
temp_dir = None
if not is_spaces:
    # Outside Spaces, keep Gradio's uploaded/temp files in a local folder
    # instead of the system temp dir.
    temp_dir = os.path.join(project_dir, ".gradio_temp")
    os.makedirs(temp_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = temp_dir

# Model repo or local path; overridable via the MODEL_PATH env var.
MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
class ModelHandler:
    """Lazily loads the Qwen2.5-VL model/processor and runs multimodal inference."""

    def __init__(self, model_path):
        # Heavy model loading is deferred until the first predict() call.
        self.model_path = model_path
        self.model = None
        self.processor = None

    def _load_model(self):
        """Load the processor and model once; no-op if already loaded."""
        if self.model is not None:
            return
        try:
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            try:
                cuda_available = torch.cuda.is_available()
                if cuda_available:
                    torch_dtype = torch.bfloat16
                else:
                    torch_dtype = torch.float32
            except RuntimeError:
                # CUDA probe can raise in some environments (e.g. stateless
                # ZeroGPU workers); default to bf16 in that case.
                torch_dtype = torch.bfloat16
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=torch_dtype,
                device_map="auto",
                attn_implementation="sdpa",
                trust_remote_code=True
            )
        except Exception as e:
            raise e

    def predict(self, message_dict, history, temperature, max_tokens):
        """Generate a reply for the current multimodal message.

        Args:
            message_dict: dict with "text" (str) and "files" (list of image
                paths) for the current user turn.
            history: prior chat turns as {"role", "content"} dicts; user
                content may be a string or a list of strings/dicts.
            temperature: sampling temperature; <= 0 disables sampling.
            max_tokens: cap on newly generated tokens.

        Yields:
            The decoded reply text (a single final chunk), a warning string
            when the model produced nothing, or an error string on failure.
        """
        if self.model is None:
            self._load_model()
        text = message_dict.get("text", "")
        files = message_dict.get("files", [])
        messages = []
        if history:
            # Convert Gradio-style history into Qwen chat-template messages.
            for msg in history:
                role = msg.get("role", "")
                content = msg.get("content", "")
                if role == "user":
                    user_content = []
                    if isinstance(content, list):
                        for item in content:
                            if isinstance(item, str):
                                # Heuristic: existing paths or image-like
                                # extensions are treated as images.
                                if os.path.exists(item) or any(item.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']):
                                    user_content.append({"type": "image", "image": item})
                                else:
                                    user_content.append({"type": "text", "text": item})
                            elif isinstance(item, dict):
                                user_content.append(item)
                    elif isinstance(content, str):
                        if content:
                            user_content.append({"type": "text", "text": content})
                    if user_content:
                        messages.append({"role": "user", "content": user_content})
                elif role == "assistant":
                    if isinstance(content, str) and content:
                        messages.append({"role": "assistant", "content": content})
        # Current turn: images first, then the question with the (whitespace-
        # collapsed) system prompt appended.
        current_content = []
        if files:
            for file_path in files:
                current_content.append({"type": "image", "image": file_path})
        if text:
            sys_prompt_formatted = " ".join(sys_prompt.split())
            full_text = f"{text}\n{sys_prompt_formatted}"
            current_content.append({"type": "text", "text": full_text})
        if current_content:
            messages.append({"role": "user", "content": current_content})
        text_prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True if temperature > 0 else False,
        )
        try:
            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)
            # Strip the prompt tokens: generate() returns prompt + completion.
            input_length = inputs['input_ids'].shape[1]
            generated_ids = generated_ids[0][input_length:]
            generated_text = self.processor.tokenizer.decode(
                generated_ids,
                skip_special_tokens=True
            )
            if generated_text and generated_text.strip():
                yield generated_text
            else:
                warning_msg = "⚠️ No output generated. The model may not have produced any response."
                yield warning_msg
        except Exception as e:
            # Surface the failure in-chat rather than crashing the UI handler.
            yield f"❌ Generation error: {str(e)}"
            return
# Process-wide singleton so the model/processor are loaded at most once.
model_handler = None


def get_model_handler():
    """Return the shared ModelHandler, constructing it on first use."""
    global model_handler
    if model_handler is None:
        model_handler = ModelHandler(MODEL_PATH)
    return model_handler
# Minimal CSS overrides: app-wide font plus a fixed-height, scrollable chat pane.
custom_css = """
.gradio-container { font-family: 'Inter', sans-serif; }
#chatbot { height: 650px !important; overflow-y: auto; }
"""
def respond(user_msg, history, temp, tokens):
    """Gradio submit handler: append the user turn, then the model reply.

    Args:
        user_msg: MultimodalTextbox value — dict with "text" and "files".
        history: chatbot history as {"role", "content"} dicts (mutated).
        temp: sampling temperature forwarded to the model.
        tokens: max new tokens forwarded to the model.

    Yields:
        (history, MultimodalTextbox update) pairs, locking the input box
        during generation and clearing/re-enabling it afterwards.
    """
    text = user_msg.get("text", "").strip()
    files = user_msg.get("files", [])
    # Ignore empty submissions (no text and no images) instead of sending an
    # empty turn to the model.
    if not text and not files:
        yield history, gr.MultimodalTextbox(value=None, interactive=True)
        return
    # Build the user turn; image files are referenced by absolute path so the
    # chatbot can render them.
    user_message = {"role": "user", "content": []}
    for file_path in files:
        if file_path:
            abs_path = file_path if os.path.isabs(file_path) else os.path.abspath(file_path)
            user_message["content"].append({"path": abs_path})
    if text:
        user_message["content"].append(text)
    if not files and text:
        # Text-only turns use a plain string, as Gradio expects.
        user_message["content"] = text
    history.append(user_message)
    # Lock the input box while the model is generating.
    yield history, gr.MultimodalTextbox(value=None, interactive=False)
    history.append({"role": "assistant", "content": ""})
    try:
        # Exclude the just-appended user turn and assistant placeholder.
        previous_history = history[:-2] if len(history) >= 2 else []
        handler = get_model_handler()
        generated_text = ""
        for chunk in handler.predict(user_msg, previous_history, temp, tokens):
            generated_text = chunk
            # FIX: escape HTML so the model's tags (<REASONING>, <ANSWER>, ...)
            # render as visible text; the previous replace("<", "<") was a no-op
            # left over from un-escaped "&lt;"/"&gt;" entities.
            history[-1]["content"] = html.escape(generated_text)
            yield history, gr.MultimodalTextbox(interactive=False)
    except Exception as e:
        history[-1]["content"] = f"❌ Inference error: {str(e)}"
        yield history, gr.MultimodalTextbox(interactive=True)
    # Clear and re-enable the input box.
    yield history, gr.MultimodalTextbox(value=None, interactive=True)
def create_chat_ui():
    """Build and return the Gradio Blocks UI for the Robust-R1 chat demo.

    Fixes: `theme`/`css` are gr.Blocks constructor arguments (Blocks.launch()
    does not accept them), and the Chatbot needs type="messages" because the
    submit handler appends {"role", "content"} dicts to the history.
    """
    with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft(), css=custom_css) as demo:
        with gr.Row():
            gr.Markdown("# 🤖Robust-R1:Degradation-Aware Reasoning for Robust Visual Understanding")
        with gr.Row():
            # Left column: chat pane + multimodal input box.
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Chat",
                    type="messages",  # history entries are {"role", "content"} dicts
                    avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
                    height=650
                )
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Enter your question or upload an image...",
                    show_label=False
                )
            # Right column: generation settings and examples.
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### ⚙️ Generation Config")
                    temperature = gr.Slider(
                        minimum=0.01, maximum=1.0, value=0.6, step=0.05,
                        label="Temperature"
                    )
                    max_tokens = gr.Slider(
                        minimum=128, maximum=4096, value=1024, step=128,
                        label="Max New Tokens"
                    )
                    clear_btn = gr.Button("🗑️ Clear Context", variant="stop")
                gr.Markdown("---")
                gr.Markdown("### 📚 Examples")
                gr.Markdown("Click the examples below to quickly fill the input box and start a conversation")
                example_images_dir = os.path.join(project_dir, "assets")
                examples_config = [
                    ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", os.path.join(example_images_dir, "1.jpg")),
                    ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", os.path.join(example_images_dir, "2.jpg")),
                ]
                # Only surface examples whose image files actually exist.
                example_data = [
                    {"text": example_text, "files": [img_path]}
                    for example_text, img_path in examples_config
                    if os.path.exists(img_path)
                ]
                if example_data:
                    gr.Examples(
                        examples=example_data,
                        inputs=chat_input,
                        label="",
                        examples_per_page=3
                    )
                else:
                    gr.Markdown("*No example images available, please manually upload images for testing*")
        # Wire events: submitting runs inference; clear resets chat and input.
        chat_input.submit(
            respond,
            inputs=[chat_input, chatbot, temperature, max_tokens],
            outputs=[chatbot, chat_input]
        )

        def clear_history():
            return [], None

        clear_btn.click(clear_history, outputs=[chatbot, chat_input])
    return demo
if __name__ == "__main__":
    demo = create_chat_ui()
    # FIX: Blocks.launch() does not accept theme/css keyword arguments — they
    # belong to the gr.Blocks(...) constructor — so they are not passed here
    # (previously this raised a TypeError at startup).
    if is_spaces:
        # On Spaces, let Gradio pick the server binding; only expose assets.
        demo.launch(
            show_error=True,
            allowed_paths=[project_dir]
        )
    else:
        # Local run: bind on all interfaces at the conventional Gradio port.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            allowed_paths=[project_dir]
        )