import gradio as gr
import os
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import html

# Detect whether we are running on Hugging Face Spaces; the `spaces.GPU`
# decorator is only available (and only needed) there.
is_spaces = os.getenv("SPACE_ID") is not None
spaces_available = False
GPU = None
if is_spaces:
    try:
        from spaces import GPU
        spaces_available = True
    except ImportError:
        pass


def gpu_decorator(func):
    """Wrap `func` with `spaces.GPU` on Spaces; otherwise return it unchanged."""
    if spaces_available and GPU is not None:
        return GPU(func)
    return func


sys_prompt = """First output the types of degradations in the image briefly in tags, then output what effects these degradations have on the image in tags, then, based on the strength of the degradation, produce a reasoning process of APPROPRIATE length in tags, then summarize the reasoning and give the answer in tags, providing the user with the answer briefly."""

project_dir = os.path.dirname(os.path.abspath(__file__))
temp_dir = None
if not is_spaces:
    temp_dir = os.path.join(project_dir, ".gradio_temp")
    os.makedirs(temp_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = temp_dir

MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")


class ModelHandler:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.processor = None

    def _load_model(self):
        """Lazily load the processor and model on first use."""
        if self.model is not None:
            return
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        try:
            torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        except RuntimeError:
            # CUDA queries can fail before a GPU is attached (e.g. ZeroGPU);
            # default to bfloat16 for when it becomes available.
            torch_dtype = torch.bfloat16
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_path,
            torch_dtype=torch_dtype,
            device_map="auto",
            attn_implementation="sdpa",
            trust_remote_code=True,
        )

    def predict(self, message_dict, history, temperature, max_tokens):
        if self.model is None:
            self._load_model()

        text = message_dict.get("text", "")
        files = message_dict.get("files", [])

        # Rebuild the prior conversation in the Qwen2.5-VL chat format.
        messages = []
        if history:
            for msg in history:
                role = msg.get("role", "")
                content = msg.get("content", "")
                if role == "user":
                    user_content = []
                    if isinstance(content, list):
                        for item in content:
                            if isinstance(item, str):
                                # Treat existing paths / image-like names as images.
                                if os.path.exists(item) or any(
                                    item.lower().endswith(ext)
                                    for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
                                ):
                                    user_content.append({"type": "image", "image": item})
                                else:
                                    user_content.append({"type": "text", "text": item})
                            elif isinstance(item, dict):
                                # Chatbot file entries arrive as {"path": ...};
                                # convert them to the Qwen image format.
                                if "path" in item:
                                    user_content.append({"type": "image", "image": item["path"]})
                                else:
                                    user_content.append(item)
                    elif isinstance(content, str):
                        if content:
                            user_content.append({"type": "text", "text": content})
                    if user_content:
                        messages.append({"role": "user", "content": user_content})
                elif role == "assistant":
                    if isinstance(content, str) and content:
                        messages.append({"role": "assistant", "content": content})

        current_content = []
        if files:
            for file_path in files:
                current_content.append({"type": "image", "image": file_path})
        if text:
            # Collapse the multi-line system prompt into one line and append
            # it after the user's question.
            sys_prompt_formatted = " ".join(sys_prompt.split())
            full_text = f"{text}\n{sys_prompt_formatted}"
            current_content.append({"type": "text", "text": full_text})
        if current_content:
            messages.append({"role": "user", "content": current_content})

        text_prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )
        try:
            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)
            # Strip the prompt tokens so only the completion is decoded.
            input_length = inputs['input_ids'].shape[1]
            generated_ids = generated_ids[0][input_length:]
            generated_text = self.processor.tokenizer.decode(
                generated_ids, skip_special_tokens=True
            )
            if generated_text and generated_text.strip():
                yield generated_text
            else:
                yield "⚠️ No output generated. The model may not have produced any response."
        except Exception as e:
            yield f"❌ Generation error: {str(e)}"
            return


model_handler = None


def get_model_handler():
    global model_handler
    if model_handler is None:
        model_handler = ModelHandler(MODEL_PATH)
    return model_handler


custom_css = """
.gradio-container { font-family: 'Inter', sans-serif; }
#chatbot { height: 650px !important; overflow-y: auto; }
"""
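# For reference, `ModelHandler.predict` hands the processor a message list
# shaped like the sketch below (paths are illustrative placeholders, not
# files shipped with this demo):
#
#   [{"role": "user",
#     "content": [
#         {"type": "image", "image": "/path/to/image.jpg"},
#         {"type": "text", "text": "What is shown?\n<system prompt>"},
#     ]}]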
@gpu_decorator
def respond(user_msg, history, temp, tokens):
    text = user_msg.get("text", "").strip()
    files = user_msg.get("files", [])

    # Echo the user's turn into the chat history (messages format).
    user_message = {"role": "user", "content": []}
    for file_path in files:
        if file_path:
            abs_path = os.path.abspath(file_path) if not os.path.isabs(file_path) else file_path
            user_message["content"].append({"path": abs_path})
    if text:
        user_message["content"].append(text)
    if not files and text:
        # Text-only turns use a plain string instead of a content list.
        user_message["content"] = text
    history.append(user_message)
    yield history, gr.MultimodalTextbox(value=None, interactive=False)

    history.append({"role": "assistant", "content": ""})
    try:
        # Exclude the two turns we just appended when rebuilding prior context.
        previous_history = history[:-2] if len(history) >= 2 else []
        handler = get_model_handler()
        generated_text = ""
        for chunk in handler.predict(user_msg, previous_history, temp, tokens):
            generated_text = chunk
            # Escape angle brackets so the model's reasoning tags render
            # literally in the chatbot instead of being parsed as HTML.
            history[-1]["content"] = html.escape(generated_text)
            yield history, gr.MultimodalTextbox(interactive=False)
    except Exception as e:
        history[-1]["content"] = f"❌ Inference error: {str(e)}"
        yield history, gr.MultimodalTextbox(interactive=True)
    yield history, gr.MultimodalTextbox(value=None, interactive=True)
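# The resulting `type="messages"` Chatbot history therefore looks like
# (illustrative values):
#
#   [{"role": "user", "content": [{"path": "/abs/path/img.jpg"}, "question"]},
#    {"role": "assistant", "content": "escaped model reply"}]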
def create_chat_ui():
    # `theme` and `css` are gr.Blocks parameters, not launch() parameters.
    with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft(), css=custom_css) as demo:
        with gr.Row():
            gr.Markdown("# 🤖 Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding")
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Chat",
                    type="messages",  # respond() appends {"role": ..., "content": ...} dicts
                    avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
                    height=650,
                )
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Enter your question or upload an image...",
                    show_label=False,
                )
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### ⚙️ Generation Config")
                    temperature = gr.Slider(
                        minimum=0.01, maximum=1.0, value=0.6, step=0.05,
                        label="Temperature",
                    )
                    max_tokens = gr.Slider(
                        minimum=128, maximum=4096, value=1024, step=128,
                        label="Max New Tokens",
                    )
                    clear_btn = gr.Button("🗑️ Clear Context", variant="stop")
                gr.Markdown("---")
                gr.Markdown("### 📚 Examples")
                gr.Markdown("Click the examples below to quickly fill the input box and start a conversation")
                example_images_dir = os.path.join(project_dir, "assets")
                examples_config = [
                    ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n",
                     os.path.join(example_images_dir, "1.jpg")),
                    ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n",
                     os.path.join(example_images_dir, "2.jpg")),
                ]
                # Only offer examples whose image files actually exist.
                example_data = []
                for text, img_path in examples_config:
                    if os.path.exists(img_path):
                        example_data.append({"text": text, "files": [img_path]})
                if example_data:
                    gr.Examples(
                        examples=example_data,
                        inputs=chat_input,
                        label="",
                        examples_per_page=3,
                    )
                else:
                    gr.Markdown("*No example images available, please manually upload images for testing*")

        chat_input.submit(
            respond,
            inputs=[chat_input, chatbot, temperature, max_tokens],
            outputs=[chatbot, chat_input],
        )

        def clear_history():
            return [], None

        clear_btn.click(clear_history, outputs=[chatbot, chat_input])
    return demo


if __name__ == "__main__":
    demo = create_chat_ui()
    allowed_paths = [project_dir] if project_dir else None
    if is_spaces:
        demo.launch(show_error=True, allowed_paths=allowed_paths)
    else:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            allowed_paths=allowed_paths,
        )
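# Quick smoke test without launching the UI (a minimal sketch; assumes the
# model weights are downloadable and that an image exists at the given path):
#
#   handler = get_model_handler()
#   msg = {"text": "What degradations affect this image?", "files": ["assets/1.jpg"]}
#   for chunk in handler.predict(msg, [], 0.6, 512):
#       print(chunk)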