#!/usr/bin/env python3
# ---
# title: ZeusMM Chat
# emoji: 🤖
# colorFrom: indigo
# colorTo: purple
# sdk: gradio
# sdk_version: 5.0.1
# app_file: app.py
# pinned: false
# ---

import os
import threading

import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    TextIteratorStreamer,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# ===== Env & Model config =====
os.environ.setdefault("ACCELERATE_DISABLE_MAPPED_DEVICE", "1")  # avoid meta-tensors on CPU
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster downloads in Spaces

MODEL_ID = os.environ.get("MODEL_ID", "Wonder-Griffin/ZeusMM-SFT-oasst1")
HF_TOKEN = os.environ.get("HF_TOKEN")  # add as a Space secret if the model is private
IS_GPU = torch.cuda.is_available()

# Optional: pin to a specific revision to avoid surprise code updates
MODEL_REVISION = os.environ.get("MODEL_REVISION")  # e.g., a commit SHA; leave unset to use latest


# ===== Robust CPU loader: builds real tensors, no meta, then loads weights =====
def load_cpu_no_meta(model_id: str, hf_token: str | None = None, revision: str | None = None):
    cfg = AutoConfig.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
        revision=revision,
    )
    model = AutoModelForCausalLM.from_config(
        cfg,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    # Allocate real storage on CPU for all params/buffers
    model.to_empty(device="cpu")

    # Find and load the primary weight file
    # (adjust filename if your repo uses something else)
    weights_path = hf_hub_download(
        repo_id=model_id,
        filename="model.safetensors",
        token=hf_token,
        revision=revision,
    )
    state = load_file(weights_path)  # safetensors -> state_dict
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        # Print to Space logs; non-fatal if they are non-critical heads/keys
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)

    model.eval()
    return model


# ===== Tokenizer (shared) =====
tok_kwargs = {"trust_remote_code": True}
if HF_TOKEN:
    tok_kwargs["token"] = HF_TOKEN
if MODEL_REVISION:
    tok_kwargs["revision"] = MODEL_REVISION
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **tok_kwargs)

# ===== Model (GPU uses device_map, CPU uses robust loader) =====
if IS_GPU:
    mdl_kwargs = dict(
        trust_remote_code=True,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager",  # stable across kernels
    )
    if HF_TOKEN:
        mdl_kwargs["token"] = HF_TOKEN
    if MODEL_REVISION:
        mdl_kwargs["revision"] = MODEL_REVISION
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **mdl_kwargs)
else:
    model = load_cpu_no_meta(MODEL_ID, HF_TOKEN, MODEL_REVISION)


# ===== Prompt building =====
def build_prompt(system_message: str, history: list[tuple[str, str]], user_message: str) -> str:
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for u, a in (history or []):
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": user_message})

    # Prefer the model's own chat template when one is available
    if hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass

    # Fallback (generic): the system message is already in `messages`,
    # so emit each message exactly once
    out = []
    for m in messages:
        role = (m.get("role") or "user").upper()
        out.append(f"[{role}] {m.get('content', '')}\n")
    out.append("[ASSISTANT] ")
    return "".join(out)


# ===== Generation (streaming) =====
def respond(message, history, system_message, max_tokens, temperature, top_p):
    prompt = build_prompt(system_message, history, message)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Send inputs to the same device as the first model parameter (works for CPU/GPU)
    first_param_device = next(model.parameters()).device
    inputs = {k: v.to(first_param_device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        streamer=streamer,
        # Fall back to EOS as padding if the tokenizer defines no pad token
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )

    # Run generation in a background thread so we can stream tokens as they arrive
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial


# ===== UI =====
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="ZeusMM Chat",
    description="Chat with your ZeusMM-SFT model with streaming responses.",
)

# Expose for Spaces
app = demo

if __name__ == "__main__":
    # queue helps avoid cold-start timeouts and enables token streaming;
    # Gradio 4+/5 uses default_concurrency_limit (concurrency_count was removed)
    demo.queue(max_size=32, default_concurrency_limit=1).launch(
        server_name="0.0.0.0", server_port=7860
    )
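
# Usage sketch (assumption, not part of the Space configuration): to run this app
# locally against a private model, export the same variables the script reads
# (HF_TOKEN, MODEL_ID, MODEL_REVISION) before launching, e.g.
#
#   export HF_TOKEN=<your-token>          # only needed for private repos
#   export MODEL_REVISION=<commit-sha>    # optional pin; leave unset for latest
#   python app.py
#
# The Gradio UI is then served on http://0.0.0.0:7860, matching the
# server_name/server_port arguments passed to launch() above. On a Space,
# set HF_TOKEN as a secret and MODEL_ID/MODEL_REVISION as variables instead.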