# app.py
# ===================== Imports =====================
import spaces  # ZeroGPU requirement: import early so @spaces.GPU is available

from copy import deepcopy
from threading import Thread

import gradio as gr
import librosa
import soundfile as sf
import torch
from qwen_vl_utils import fetch_image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)
# ===================== Load Model (lazy move to GPU) =====================
model_path = "FreedomIntelligence/ShizhenGPT-7B-Omni"

# Load the weights on CPU first; move them to the GPU only when inference
# actually happens (the allocation is managed by @spaces.GPU).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # fall back to torch.float16 if bfloat16 is unsupported
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True,
)
model.eval()

# Some checkpoints keep the chat_template on the tokenizer; copy it over for compatibility
if hasattr(processor, "tokenizer") and hasattr(processor.tokenizer, "chat_template"):
    processor.chat_template = processor.tokenizer.chat_template

# Flag: move the model to the GPU only on the first inference call
_MODEL_ON_CUDA = False
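
# A sketch of the assumed ZeroGPU lifecycle behind this lazy setup (summarized
# from how Spaces ZeroGPU behaves in practice, not an API contract):
#   1. The Space boots on a CPU-only worker, so from_pretrained() above runs on CPU.
#   2. When a request reaches a function decorated with @spaces.GPU (predict below),
#      ZeroGPU attaches a GPU for the duration of that call.
#   3. The first such call performs model.to("cuda") once; _MODEL_ON_CUDA keeps
#      subsequent calls from repeating the transfer.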
# ===================== Streaming Generation =====================
def generate_with_streaming(model, processor, text, images=None, audios=None, history=None):
    # Process images: filter out empty entries first so the number of vision
    # tokens prepended to the text matches the number of images actually used
    processed_images = None
    if images:
        processed_images = [
            fetch_image({"type": "image", "image": img, "max_pixels": 360 * 420})
            for img in images
            if img is not None
        ] or None
        if processed_images:
            text = "".join(["<|vision_start|><|image_pad|><|vision_end|>"] * len(processed_images)) + text

    # Process audios: same None-filtering before prepending audio tokens
    processed_audios = None
    if audios:
        processed_audios = [audio for audio in audios if audio is not None] or None
        if processed_audios:
            text = "".join(["<|audio_bos|><|AUDIO|><|audio_eos|>"] * len(processed_audios)) + text
    # Build conversation history
    messages = []
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

    # Strip multimodal placeholder tokens from previous turns
    for m in messages:
        m["content"] = m["content"].replace("<|audio_bos|><|AUDIO|><|audio_eos|>", "").replace(
            "<|vision_start|><|image_pad|><|vision_end|>", ""
        )

    # Add the current user input
    messages.append({"role": "user", "content": text})

    # Prepare model input
    templated = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) or ""
    input_data = processor(
        text=[templated],
        audios=processed_audios,
        images=processed_images,
        return_tensors="pt",
        padding=True,
    )
    # Move tensors to the current model device
    for k, v in input_data.items():
        if hasattr(v, "to"):
            input_data[k] = v.to(model.device)

    # Start streaming generation on a background thread
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(
        **input_data,
        streamer=streamer,
        max_new_tokens=1500,
        do_sample=True,
        temperature=0.2,
        top_p=0.8,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield generated text chunks as they arrive
    for new_text in streamer:
        yield new_text
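
# A hypothetical standalone use of the generator above, handy for local smoke
# tests outside Gradio (the prompt text is illustrative only):
#
#   for piece in generate_with_streaming(model, processor, "Hello, please introduce yourself."):
#       print(piece, end="", flush=True)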
# ===================== Audio Preprocessing =====================
def process_audio(audio):
    """Convert recorded/loaded audio to the sampling rate the model expects."""
    if audio is None:
        return None
    try:
        sr, y = audio
        if y.ndim > 1:
            y = y[:, 0]  # Keep the first channel only
        save_path = "./temp.wav"
        sf.write(save_path, y, sr)
        # Some processors have no feature_extractor; fall back to 16 kHz
        target_sr = getattr(
            getattr(processor, "feature_extractor", None), "sampling_rate", 16000
        )
        y_resampled, _ = librosa.load(save_path, sr=target_sr, mono=True)
        return y_resampled
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None
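
# The tuple unpacked above is what gr.Audio(type="numpy") hands to callbacks:
# (sample_rate, data), with data shaped (n_samples,) or (n_samples, n_channels).
# A minimal smoke test (one second of silence; values are illustrative):
#
#   import numpy as np
#   silent = (44100, np.zeros(44100, dtype=np.int16))
#   out = process_audio(silent)  # -> mono float array resampled to target_sr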
# ===================== Prediction Function for Gradio (ZeroGPU) =====================
# Key point: the @spaces.GPU decorator lets ZeroGPU detect this as a GPU
# function; `duration` (seconds) raises the per-call GPU time budget for long
# generations (120 here is an assumed value; adjust as needed).
@spaces.GPU(duration=120)
def predict(message, image, audio, chatbox):
    global _MODEL_ON_CUDA, model

    # ZeroGPU only allocates a GPU once the decorated function is called,
    # so the model is moved to CUDA here, on the first call
    if not _MODEL_ON_CUDA:
        model.to("cuda")
        _MODEL_ON_CUDA = True

    chat_history = deepcopy(chatbox)
    processed_audio = [process_audio(audio)] if audio is not None else None
    processed_image = [image] if image is not None else None

    chatbox.append([message, ""])
    response = ""
    # Stream the model response; yielding from this generator makes the
    # Gradio Chatbot refresh in real time
    for chunk in generate_with_streaming(model, processor, message, processed_image, processed_audio, chat_history):
        response += chunk
        chatbox[-1][1] = response
        yield chatbox

    print("\n=== Complete Model Response ===")
    print(response)
    print("============================\n", flush=True)
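
# Note: this assumes gr.Chatbot's "tuples" history format, i.e.
#   [["user turn 1", "assistant turn 1"], ["user turn 2", ""]]
# where the trailing empty string is filled in as chunks stream back.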
# ===================== CSS for UI =====================
css = """
.gradio-container {
    background-color: #f7f7f7;
    font-family: 'Arial', sans-serif;
}
.chat-message {
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 10px;
}
.user-message {
    background-color: #e6f7ff;
    border-left: 5px solid #1890ff;
}
.bot-message {
    background-color: #f2f2f2;
    border-left: 5px solid #52c41a;
}
.title {
    text-align: center;
    color: #1890ff;
    font-size: 24px;
    margin-bottom: 20px;
}
"""
# ===================== Gradio App =====================
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 class='title'>ShizhenGPT: Multimodal LLM for Traditional Chinese Medicine (Vision, Audio, Text)</h1>")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            message = gr.Textbox(label="Your Question", placeholder="Type your question here...")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload Image")
            audio_input = gr.Audio(type="numpy", label="Record or Upload Audio")

    # examples_list = [
    #     ["段某,男,49岁。平素性急,时而头晕,有高血压史。今日中午突然昏仆,不省人事,牙关紧闭,两手握固,肢体强痉,面赤身热,苔黄腻,脉弦滑而数。请你分析病情,并给出可能的医治方案。", None, None, []],
    #     ["请分析图中所示舌象,并告诉我哪种舌苔表现与之相符?", "examples/images.JPG", None, []],  # make sure this image path exists
    # ]
    # gr.Examples(
    #     examples=examples_list,
    #     inputs=[message, image_input, audio_input, chatbot],  # FIX: bind the real inputs
    #     label="Examples",
    #     cache_examples=False,  # avoids example-caching errors in some environments
    # )
    # Submit button: run prediction
    submit_btn.click(
        predict,
        inputs=[message, image_input, audio_input, chatbot],
        outputs=[chatbot],
        show_progress=True,
    ).then(
        lambda: "",  # Clear the textbox after sending
        outputs=[message],
    )

    # Clear button: reset inputs and chat history
    clear_btn.click(
        lambda: (None, None, None, []),
        outputs=[message, image_input, audio_input, chatbot],
    )
# ===================== Run App =====================
if __name__ == "__main__":
    # Under ZeroGPU, keep concurrency low so concurrent requests don't each
    # try to claim a GPU (default_concurrency_limit assumes Gradio 4.x+)
    demo.queue(default_concurrency_limit=1).launch(
        server_name="0.0.0.0", server_port=7860, share=True
    )
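
# ---------------------------------------------------------------------------
# Assumed Space dependencies (a requirements.txt sketch; versions unpinned and
# package names inferred from the imports above, so verify before relying on it):
#
#   gradio
#   torch
#   transformers
#   librosa
#   soundfile
#   qwen-vl-utils
#   spaces
# ---------------------------------------------------------------------------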