import gradio as gr
import os
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import html
# Import the spaces module for GPU detection
is_spaces = os.getenv("SPACE_ID") is not None
spaces_available = False
GPU = None
if is_spaces:
    try:
        from spaces import GPU
        spaces_available = True
    except ImportError:
        print("⚠️ spaces module not available, GPU detection may not work")

# Conditional decorator
def gpu_decorator(func):
    """Apply the spaces GPU decorator only when it is available."""
    if spaces_available and GPU is not None:
        return GPU(func)
    return func
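# Usage sketch (hypothetical; nothing in this file is decorated yet): on a
# ZeroGPU Space, wrapping the inference entry point with @gpu_decorator asks
# for a GPU slice for the duration of the call, e.g.
#   @gpu_decorator
#   def run_inference(...): ...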
# Conditionally enable flash-attn (deferred to model load time to avoid a CUDA
# check at startup).
# Note: in a ZeroGPU environment, CUDA may not be available yet at startup;
# whether flash-attn is used is decided at model load time based on actual
# CUDA availability.
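
# System prompt enforcing Robust-R1's structured output: the model first names
# the degradations, then their effects on the image, then a reasoning trace, a
# conclusion, and finally a brief answer, each wrapped in the tags below.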
| sys_prompt = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags, | |
| and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags, | |
| then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags, | |
| and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags, | |
| provides the user with the answer briefly in <ANSWER> <ANSWER_END>.""" | |
project_dir = os.path.dirname(os.path.abspath(__file__))

# Outside of Spaces, keep Gradio's temp files inside the project directory.
if not is_spaces:
    temp_dir = os.path.join(project_dir, ".gradio_temp")
    os.makedirs(temp_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = temp_dir

MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")

print("==========================================")
print("Initializing application...")
print("==========================================")
class ModelHandler:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.processor = None
        self._load_model()

    def _load_model(self):
        try:
            print("⏳ Loading model weights, this may take a few minutes...")
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            if torch.cuda.is_available():
                # FlashAttention-2 needs compute capability >= 8 (Ampere or newer).
                device_capability = torch.cuda.get_device_capability()
                use_flash_attention = device_capability[0] >= 8
                print(f"🔧 CUDA available, device capability: {device_capability}")
            else:
                use_flash_attention = False
                print("🔧 Using CPU or non-CUDA device")
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
                # SDPA is used for now; switch back to the line below to enable
                # flash_attention_2 on capable devices:
                # attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
                attn_implementation="sdpa",
                trust_remote_code=True
            )
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            raise
    def predict(self, messages, temperature, max_tokens):
        # `messages` is already a list in standard OpenAI chat format.
        # Inject sys_prompt into the last user message.
        if messages and messages[-1]["role"] == "user":
            content = messages[-1]["content"]
            sys_prompt_fmt = "\n" + " ".join(sys_prompt.split())
            if isinstance(content, str):
                messages[-1]["content"] += sys_prompt_fmt
            elif isinstance(content, list):
                # Find the text part and append to it; add one if missing.
                text_found = False
                for item in content:
                    if item.get("type") == "text":
                        item["text"] += sys_prompt_fmt
                        text_found = True
                        break
                if not text_found:
                    content.append({"type": "text", "text": sys_prompt_fmt})
        text_prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )
        with torch.no_grad():
            generated_ids = self.model.generate(**generation_kwargs)
        # Strip the prompt tokens so only newly generated text is decoded.
        input_length = inputs['input_ids'].shape[1]
        generated_ids = generated_ids[0][input_length:]
        generated_text = self.processor.tokenizer.decode(
            generated_ids, skip_special_tokens=True
        )
        # Generator interface for the streaming loop in respond(); generate()
        # is non-streaming, so the full decoded text is yielded once.
        yield generated_text
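
# Module-level singleton: the model is loaded lazily on the first request
# rather than at import time, which keeps Space startup fast and matches the
# ZeroGPU note above (CUDA may not be available until a request arrives).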
model_handler = None

def get_model_handler():
    """Get the model handler, loading the model lazily on first use."""
    global model_handler
    if model_handler is None:
        print("🔄 Initializing model handler...")
        model_handler = ModelHandler(MODEL_PATH)
    return model_handler
def _history_to_messages(history):
    """
    Convert Gradio tuple-style history [[user, bot], ...] into an OpenAI-format
    message list suitable for the model.
    """
    messages = []
    for pair in history:
        user_msg, bot_msg = pair
        # --- User message ---
        if user_msg:
            # Check whether this is a file path (image).
            # In Gradio, images are usually temp paths or http links.
            is_image = False
            if isinstance(user_msg, str):
                if os.path.exists(user_msg) or user_msg.startswith("http"):
                    # Simple heuristic: an existing path or URL that looks like an image.
                    lower_msg = user_msg.lower()
                    if any(lower_msg.endswith(ext) for ext in ['.jpg', '.png', '.jpeg', '.webp', '.bmp']):
                        is_image = True
            # Build the user content
            if is_image:
                # A standalone image message.
                # Note: for best model results the image and the text that follows
                # it would ideally be merged into one message, but for simplicity
                # they are kept separate; most VLMs handle this fine.
                messages.append({
                    "role": "user",
                    "content": [{"type": "image", "image": user_msg}]
                })
            else:
                # A text message.
                # If the previous message is also a user image it could be merged
                # (optional; for simplicity we just append).
                messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": str(user_msg)}]
                })
        # --- Assistant message ---
        if bot_msg:
            messages.append({
                "role": "assistant",
                "content": [{"type": "text", "text": str(bot_msg)}]
            })
    return messages
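
# Illustrative result for history = [["/tmp/cat.jpg", None], ["What is shown?", None]]
# (hypothetical path, for shape only):
#   [{"role": "user", "content": [{"type": "image", "image": "/tmp/cat.jpg"}]},
#    {"role": "user", "content": [{"type": "text", "text": "What is shown?"}]}]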
# ==========================================
def respond(user_msg, history, temp, tokens):
    """
    user_msg: {'text': '...', 'files': ['...']}
    history: List[List[str | None]]  <- legacy tuple format
    """
    files = user_msg.get("files", [])
    text = user_msg.get("text", "")

    # 1. Update the UI history (show the user's input immediately)
    # -------------------------------------------------
    # Images first: each image is appended to history as its own (path, None)
    # tuple; the Gradio Chatbot recognizes the path and renders the image.
    for f in files:
        history.append([f, None])
    # Then the text
    if text:
        # With text present, append (text, None) so the bot can answer it.
        history.append([text, None])
    elif not text and files:
        # Image(s) only, no text: a (None, None) placeholder may error and
        # ("(image uploaded)", None) would work, but the reply generated below
        # immediately fills the bot slot of the last image row, so no
        # placeholder is needed.
        pass
    # History now contains the user's input
    yield history, gr.MultimodalTextbox(value=None, interactive=False)

    # 2. Call the model
    # -------------------------------------------------
    try:
        handler = get_model_handler()
        # Convert the tuple history into a message list the model understands,
        # including the user-only rows just appended.
        messages = _history_to_messages(history)
        # Reserve a slot in the UI for the bot's reply:
        # if the last row is [user_content, None], turn it into [user_content, ""].
        if history and history[-1][1] is None:
            history[-1][1] = ""
        else:
            # Fallback: if there is no matching row, add one.
            history.append([None, ""])
        # Streamed generation
        full_response = ""
        for chunk in handler.predict(messages, temp, tokens):
            full_response += chunk
            history[-1][1] = full_response  # update the bot's reply in the last row
            yield history, gr.MultimodalTextbox(interactive=False)
    except Exception as e:
        import traceback
        traceback.print_exc()
        err_msg = f"❌ Error: {str(e)}"
        if history and history[-1][1] is None:
            history[-1][1] = err_msg
        else:
            history.append([None, err_msg])
        yield history, gr.MultimodalTextbox(interactive=True)
    # Re-enable the input box once generation finishes.
    yield history, gr.MultimodalTextbox(interactive=True)
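
# Each yield in respond() emits a (chatbot, chat_input) pair, matching the
# outputs=[chatbot, chat_input] wiring of chat_input.submit() below.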
def create_chat_ui():
    custom_css = """
    .gradio-container { font-family: 'Inter', sans-serif; }
    #chatbot { height: 650px !important; overflow-y: auto; }
    """
    # theme and css are constructor arguments of gr.Blocks, not of launch().
    with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft(), css=custom_css) as demo:
        with gr.Row():
            gr.Markdown("# 🤖 Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding")
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Chat",
                    avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
                    height=650,
                )
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Enter your question or upload an image...",
                    show_label=False
                )
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### ⚙️ Generation Config")
                    temperature = gr.Slider(
                        minimum=0.01, maximum=1.0, value=0.6, step=0.05,
                        label="Temperature"
                    )
                    max_tokens = gr.Slider(
                        minimum=128, maximum=4096, value=1024, step=128,
                        label="Max New Tokens"
                    )
                    clear_btn = gr.Button("🗑️ Clear Context", variant="stop")
                gr.Markdown("---")
                gr.Markdown("### 📚 Examples")
                gr.Markdown("Click the examples below to quickly fill the input box and start a conversation.")
                example_images_dir = os.path.join(project_dir, "assets")
                examples_config = [
                    ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", os.path.join(example_images_dir, "1.jpg")),
                    ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", os.path.join(example_images_dir, "2.jpg")),
                ]
                example_data = []
                for text, img_path in examples_config:
                    if os.path.exists(img_path):
                        example_data.append({"text": text, "files": [img_path]})
                if example_data:
                    gr.Examples(
                        examples=example_data,
                        inputs=chat_input,
                        label="",
                        examples_per_page=3
                    )
                else:
                    gr.Markdown("*No example images available, please upload images manually for testing*")

        chat_input.submit(
            respond,
            inputs=[chat_input, chatbot, temperature, max_tokens],
            outputs=[chatbot, chat_input]
        )

        def clear_history():
            return [], None

        clear_btn.click(clear_history, outputs=[chatbot, chat_input])
    return demo
if __name__ == "__main__":
    demo = create_chat_ui()
    if is_spaces:
        print(f"🚀 Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
        demo.launch(
            show_error=True,
            allowed_paths=[project_dir]
        )
    else:
        print("🚀 Service is starting, please visit: http://localhost:7860")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            allowed_paths=[project_dir]
        )