#!/usr/bin/env python3
# ---
# title: ZeusMM Chat
# emoji: 🤖
# colorFrom: indigo
# colorTo: purple
# sdk: gradio
# sdk_version: 5.0.1
# app_file: app.py
# pinned: false
# ---
import os
import threading
import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    TextIteratorStreamer,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# ===== Env & Model config =====
os.environ.setdefault("ACCELERATE_DISABLE_MAPPED_DEVICE", "1") # avoid meta-tensors on CPU
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # faster downloads in Spaces
MODEL_ID = os.environ.get("MODEL_ID", "Wonder-Griffin/ZeusMM-SFT-oasst1")
HF_TOKEN = os.environ.get("HF_TOKEN") # add as a Space secret if the model is private
IS_GPU = torch.cuda.is_available()
# Optional: pin to a specific revision to avoid surprise code updates
MODEL_REVISION = os.environ.get("MODEL_REVISION") # e.g., a commit SHA; leave unset to use latest
# ===== Robust CPU loader: builds real tensors, no meta, then loads weights =====
def load_cpu_no_meta(model_id: str, hf_token: str | None = None, revision: str | None = None):
    cfg = AutoConfig.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
        revision=revision,
    )
    model = AutoModelForCausalLM.from_config(
        cfg,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    # Allocate real storage on CPU for all params/buffers
    model.to_empty(device="cpu")
    # Find and load the primary weight file
    # (adjust filename if your repo uses something else)
    weights_path = hf_hub_download(
        repo_id=model_id,
        filename="model.safetensors",
        token=hf_token,
        revision=revision,
    )
    state = load_file(weights_path)  # safetensors -> state_dict
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        # Print to Space logs; non-fatal if they are non-critical heads/keys
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)
    model.eval()
    return model
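# Note: load_cpu_no_meta() assumes a single-file checkpoint named model.safetensors.
# For a sharded repo (model-00001-of-000NN.safetensors plus model.safetensors.index.json),
# a sketch along these lines would be needed instead (hypothetical, not wired in here):
#     index = json.loads(Path(hf_hub_download(model_id, "model.safetensors.index.json")).read_text())
#     state = {}
#     for shard in sorted(set(index["weight_map"].values())):
#         state.update(load_file(hf_hub_download(model_id, shard)))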
# ===== Tokenizer (shared) =====
tok_kwargs = {"trust_remote_code": True}
if HF_TOKEN:
    tok_kwargs["token"] = HF_TOKEN
if MODEL_REVISION:
    tok_kwargs["revision"] = MODEL_REVISION
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **tok_kwargs)
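# Some causal-LM tokenizers ship without a pad token; falling back to EOS is an assumption
# (adjust if this model defines its own padding token) that keeps tokenization and
# generate() from warning about a missing pad_token_id.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token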
# ===== Model (GPU uses device_map, CPU uses robust loader) =====
if IS_GPU:
    mdl_kwargs = dict(
        trust_remote_code=True,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager",  # stable across kernels
    )
    if HF_TOKEN:
        mdl_kwargs["token"] = HF_TOKEN
    if MODEL_REVISION:
        mdl_kwargs["revision"] = MODEL_REVISION
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **mdl_kwargs)
else:
    model = load_cpu_no_meta(MODEL_ID, HF_TOKEN, MODEL_REVISION)
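# On the CPU path the model is materialized entirely in float32, so the Space needs enough
# RAM for the full model; the GPU path relies on device_map="auto" to place the weights.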
# ===== Prompt building =====
def build_prompt(system_message: str, history: list[tuple[str, str]], user_message: str) -> str:
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for u, a in (history or []):
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": user_message})
    if hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass
    # Fallback (generic): messages already contains the system turn, so render each entry once
    out = []
    for m in messages:
        role = (m.get("role") or "user").upper()
        out.append(f"[{role}] {m.get('content','')}\n")
    out.append("[ASSISTANT] ")
    return "".join(out)
# ===== Generation (streaming) =====
def respond(message, history, system_message, max_tokens, temperature, top_p):
    prompt = build_prompt(system_message, history, message)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Send inputs to the same device as the first model parameter (works for CPU/GPU)
    first_param_device = next(model.parameters()).device
    inputs = {k: v.to(first_param_device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        streamer=streamer,
    )
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
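# Caveat: if model.generate() raises inside the worker thread, the error only appears in
# the Space logs and the loop above can stall waiting for tokens; passing a timeout to
# TextIteratorStreamer (e.g., timeout=120.0) is one way to turn that into a visible error.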
# ===== UI =====
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="ZeusMM Chat",
    description="Chat with your ZeusMM-SFT model with streaming responses.",
)
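# Note: this ChatInterface uses Gradio's default tuple-style history, which is what
# respond() and build_prompt() expect; switching to type="messages" would hand respond()
# a list of {"role": ..., "content": ...} dicts and build_prompt() would need updating.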
# Expose for Spaces
app = demo
if __name__ == "__main__":
    # queue helps avoid cold-start timeouts and enables token streaming
    # (Gradio 4+/5 replaced queue(concurrency_count=...) with default_concurrency_limit)
    demo.queue(max_size=32, default_concurrency_limit=1).launch(server_name="0.0.0.0", server_port=7860)