#!/usr/bin/env python3
# ---
# title: ZeusMM Chat
# emoji: 🤖
# colorFrom: indigo
# colorTo: purple
# sdk: gradio
# sdk_version: 5.0.1
# app_file: app.py
# pinned: false
# ---
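# (The comment block above is the Hugging Face Spaces config header: Space title,
# emoji, theme colors, SDK, and entry point. Spaces normally reads this from the
# YAML front matter of README.md.)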
import os
import threading

import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    TextIteratorStreamer,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# ===== Env & Model config =====
os.environ.setdefault("ACCELERATE_DISABLE_MAPPED_DEVICE", "1")  # avoid meta-tensors on CPU
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster downloads in Spaces

MODEL_ID = os.environ.get("MODEL_ID", "Wonder-Griffin/ZeusMM-SFT-oasst1")
HF_TOKEN = os.environ.get("HF_TOKEN")  # add as a Space secret if the model is private
IS_GPU = torch.cuda.is_available()

# Optional: pin to a specific revision to avoid surprise code updates
MODEL_REVISION = os.environ.get("MODEL_REVISION")  # e.g., a commit SHA; leave unset to use latest
# ===== Robust CPU loader: builds real tensors, no meta, then loads weights =====
def load_cpu_no_meta(model_id: str, hf_token: str | None = None, revision: str | None = None):
    cfg = AutoConfig.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
        revision=revision,
    )
    model = AutoModelForCausalLM.from_config(
        cfg,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    # Allocate real storage on CPU for all params/buffers
    model.to_empty(device="cpu")
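    # NOTE: to_empty() gives every tensor real (but uninitialized) CPU storage;
    # the values only become meaningful once load_state_dict() below copies the
    # checkpoint in. Any key missing from the checkpoint stays uninitialized.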
    # Find and load the primary weight file
    # (adjust filename if your repo uses something else)
    weights_path = hf_hub_download(
        repo_id=model_id,
        filename="model.safetensors",
        token=hf_token,
        revision=revision,
    )
    state = load_file(weights_path)  # safetensors -> state_dict
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        # Print to Space logs; non-fatal if they are non-critical heads/keys
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)
    model.eval()
    return model
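# NOTE: this loader assumes a single-file checkpoint. Sharded repos instead ship
# model-00001-of-0000N.safetensors files plus a model.safetensors.index.json;
# supporting those would mean downloading every shard listed in the index.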
# ===== Tokenizer (shared) =====
tok_kwargs = {"trust_remote_code": True}
if HF_TOKEN:
    tok_kwargs["token"] = HF_TOKEN
if MODEL_REVISION:
    tok_kwargs["revision"] = MODEL_REVISION
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **tok_kwargs)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # avoid pad-token warnings in generate()
# ===== Model (GPU uses device_map, CPU uses robust loader) =====
if IS_GPU:
    mdl_kwargs = dict(
        trust_remote_code=True,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager",  # stable across kernels
    )
    if HF_TOKEN:
        mdl_kwargs["token"] = HF_TOKEN
    if MODEL_REVISION:
        mdl_kwargs["revision"] = MODEL_REVISION
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **mdl_kwargs)
else:
    model = load_cpu_no_meta(MODEL_ID, HF_TOKEN, MODEL_REVISION)
# ===== Prompt building =====
def build_prompt(system_message: str, history: list, user_message: str) -> str:
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for turn in (history or []):
        # Depending on the Gradio version / ChatInterface type, history arrives
        # either as (user, assistant) tuples or as {"role", "content"} dicts.
        if isinstance(turn, dict):
            if turn.get("content"):
                messages.append({"role": turn.get("role", "user"), "content": turn["content"]})
        else:
            u, a = turn
            if u:
                messages.append({"role": "user", "content": u})
            if a:
                messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": user_message})
    if hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass  # tokenizer has no chat template; fall back to a generic format
    # Fallback (generic): messages already contains the system turn, so emit each
    # role exactly once rather than duplicating the system message.
    out = []
    for m in messages:
        role = (m.get("role") or "user").upper()
        out.append(f"[{role}] {m.get('content', '')}\n")
    out.append("[ASSISTANT] ")
    return "".join(out)
# ===== Generation (streaming) =====
def respond(message, history, system_message, max_tokens, temperature, top_p):
    prompt = build_prompt(system_message, history, message)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Send inputs to the same device as the first model parameter (works for CPU/GPU)
    first_param_device = next(model.parameters()).device
    inputs = {k: v.to(first_param_device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
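    # TextIteratorStreamer is an iterator that yields decoded text as generate()
    # emits tokens; skip_prompt=True drops the echoed input so only newly
    # generated text reaches the UI.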
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        streamer=streamer,
    )
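    # generate() blocks until it finishes, so it runs on a worker thread while
    # this (the caller's) thread consumes the streamer and yields partial text.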
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
    t.join()  # ensure generation has fully finished before returning
# ===== UI =====
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="ZeusMM Chat",
    description="Chat with your ZeusMM-SFT model with streaming responses.",
)
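# ChatInterface calls fn with (message, history) first, then each
# additional_inputs value in order, matching respond()'s signature above.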
# Expose for Spaces
app = demo

if __name__ == "__main__":
    # queue helps avoid cold-start timeouts and enables token streaming
    # (Gradio 4+ renamed queue()'s concurrency_count to default_concurrency_limit)
    demo.queue(max_size=32, default_concurrency_limit=1).launch(server_name="0.0.0.0", server_port=7860)