Wonder-Griffin committed
Commit 9ab1828 · verified
1 Parent(s): afd30b1

Update app.py

Files changed (1)
  1. app.py +170 -22
app.py CHANGED
@@ -1,40 +1,188 @@
-import os, threading
+#!/usr/bin/env python3
+# ---
+# title: ZeusMM Chat
+# emoji: 🤖
+# colorFrom: indigo
+# colorTo: purple
+# sdk: gradio
+# sdk_version: 5.0.1
+# app_file: app.py
+# pinned: false
+# ---
+
+import os
+import threading
 import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoConfig,
+    TextIteratorStreamer,
+)
+
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+# ===== Env & Model config =====
+os.environ.setdefault("ACCELERATE_DISABLE_MAPPED_DEVICE", "1")  # avoid meta-tensors on CPU
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster downloads in Spaces
 
 MODEL_ID = os.environ.get("MODEL_ID", "Wonder-Griffin/ZeusMM-SFT-oasst1")
-HF_TOKEN = os.environ.get("HF_TOKEN")
+HF_TOKEN = os.environ.get("HF_TOKEN")  # add as a Space secret if the model is private
+IS_GPU = torch.cuda.is_available()
+
+# Optional: pin to a specific revision to avoid surprise code updates
+MODEL_REVISION = os.environ.get("MODEL_REVISION")  # e.g., a commit SHA; leave unset to use latest
+
+
+# ===== Robust CPU loader: builds real tensors, no meta, then loads weights =====
+def load_cpu_no_meta(model_id: str, hf_token: str | None = None, revision: str | None = None):
+    cfg = AutoConfig.from_pretrained(
+        model_id,
+        trust_remote_code=True,
+        token=hf_token,
+        revision=revision,
+    )
+    model = AutoModelForCausalLM.from_config(
+        cfg,
+        trust_remote_code=True,
+        torch_dtype=torch.float32,
+    )
+    # Allocate real storage on CPU for all params/buffers
+    model.to_empty(device="cpu")
+
+    # Find and load the primary weight file
+    # (adjust filename if your repo uses something else)
+    weights_path = hf_hub_download(
+        repo_id=model_id,
+        filename="model.safetensors",
+        token=hf_token,
+        revision=revision,
+    )
+    state = load_file(weights_path)  # safetensors -> state_dict
+
+    missing, unexpected = model.load_state_dict(state, strict=False)
+    if missing or unexpected:
+        # Print to Space logs; non-fatal if they are non-critical heads/keys
+        print("Missing keys:", missing)
+        print("Unexpected keys:", unexpected)
+
+    model.eval()
+    return model
 
-# Avoid Accelerate mapped-device heuristics that can create meta tensors on CPU
-os.environ.setdefault("ACCELERATE_DISABLE_MAPPED_DEVICE", "1")
-os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 
-# --- Load tokenizer ---
+# ===== Tokenizer (shared) =====
 tok_kwargs = {"trust_remote_code": True}
-if HF_TOKEN: tok_kwargs["token"] = HF_TOKEN
+if HF_TOKEN:
+    tok_kwargs["token"] = HF_TOKEN
+if MODEL_REVISION:
+    tok_kwargs["revision"] = MODEL_REVISION
+
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **tok_kwargs)
 
-# --- Load model (CPU-safe / GPU-smart) ---
-IS_GPU = torch.cuda.is_available()
+
+# ===== Model (GPU uses device_map, CPU uses robust loader) =====
 if IS_GPU:
-    # GPU: allow device_map and auto dtype, but force eager attention
     mdl_kwargs = dict(
         trust_remote_code=True,
         torch_dtype="auto",
         device_map="auto",
-        attn_implementation="eager",
+        attn_implementation="eager",  # stable across kernels
     )
-    if HF_TOKEN: mdl_kwargs["token"] = HF_TOKEN
+    if HF_TOKEN:
+        mdl_kwargs["token"] = HF_TOKEN
+    if MODEL_REVISION:
+        mdl_kwargs["revision"] = MODEL_REVISION
+
     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **mdl_kwargs)
 else:
-    # CPU: NO device_map, NO low_cpu_mem_usage -> real tensors (not meta)
-    mdl_kwargs = dict(
-        trust_remote_code=True,
-        torch_dtype=torch.float32,
-        low_cpu_mem_usage=False,
-        attn_implementation="eager",
+    model = load_cpu_no_meta(MODEL_ID, HF_TOKEN, MODEL_REVISION)
+
+
+# ===== Prompt building =====
+def build_prompt(system_message: str, history: list[tuple[str, str]], user_message: str) -> str:
+    messages = []
+    if system_message:
+        messages.append({"role": "system", "content": system_message})
+
+    for u, a in (history or []):
+        if u:
+            messages.append({"role": "user", "content": u})
+        if a:
+            messages.append({"role": "assistant", "content": a})
+
+    messages.append({"role": "user", "content": user_message})
+
+    if hasattr(tokenizer, "apply_chat_template"):
+        try:
+            return tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        except Exception:
+            pass
+
+    # Fallback (generic)
+    out = []
+    # (the system message is already first in `messages`,
+    #  so the loop below emits it; no separate prepend is needed)
+    for m in messages:
+        role = (m.get("role") or "user").upper()
+        out.append(f"[{role}] {m.get('content','')}\n")
+    out.append("[ASSISTANT] ")
+    return "".join(out)
+
+
+# ===== Generation (streaming) =====
+def respond(message, history, system_message, max_tokens, temperature, top_p):
+    prompt = build_prompt(system_message, history, message)
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    # Send inputs to the same device as the first model parameter (works for CPU/GPU)
+    first_param_device = next(model.parameters()).device
+    inputs = {k: v.to(first_param_device) for k, v in inputs.items()}
+
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        skip_prompt=True,
+        skip_special_tokens=True,
     )
-    if HF_TOKEN: mdl_kwargs["token"] = HF_TOKEN
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **mdl_kwargs)
-    model.to("cpu")
+
+    gen_kwargs = dict(
+        **inputs,
+        max_new_tokens=int(max_tokens),
+        temperature=float(temperature),
+        top_p=float(top_p),
+        do_sample=True,
+        streamer=streamer,
+    )
+
+    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
+    t.start()
+
+    partial = ""
+    for chunk in streamer:
+        partial += chunk
+        yield partial
+
+
+# ===== UI =====
+demo = gr.ChatInterface(
+    fn=respond,
+    additional_inputs=[
+        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+    ],
+    title="ZeusMM Chat",
+    description="Chat with your ZeusMM-SFT model with streaming responses.",
+)
+
+# Expose for Spaces
+app = demo
+
+if __name__ == "__main__":
+    # queue helps avoid cold-start timeouts and enables token streaming
+    demo.queue(max_size=32, default_concurrency_limit=1).launch(server_name="0.0.0.0", server_port=7860)
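
Note: since the point of this change is to avoid meta tensors when the Space falls back to CPU, a quick sanity check after the model-loading block can confirm that every parameter and buffer ended up with real storage. This is a minimal debugging sketch, not part of the commit; it only uses standard PyTorch attributes on the `model` object defined above.

# Sanity check (debugging sketch): confirm nothing is left on the meta device after loading.
meta_params = [n for n, p in model.named_parameters() if p.is_meta]
meta_buffers = [n for n, b in model.named_buffers() if b.is_meta]
if meta_params or meta_buffers:
    raise RuntimeError(f"Meta tensors remain after load: {meta_params + meta_buffers}")
print("All tensors materialized on:", next(model.parameters()).device)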
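load_cpu_no_meta assumes the repository ships a single model.safetensors file, as its inline comment notes. If the checkpoint were ever sharded (an index JSON plus several shard files), one possible adaptation is sketched below using snapshot_download from huggingface_hub; the file layout and the helper name load_sharded_state are assumptions for illustration, not something this commit implements.

# Sketch only: assembling a state dict from a sharded safetensors checkpoint (assumed layout).
import glob, os
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

def load_sharded_state(model_id, hf_token=None, revision=None):
    local_dir = snapshot_download(
        repo_id=model_id,
        allow_patterns=["*.safetensors", "*.json"],
        token=hf_token,
        revision=revision,
    )
    state = {}
    for shard in sorted(glob.glob(os.path.join(local_dir, "*.safetensors"))):
        state.update(load_file(shard))  # each shard holds a disjoint part of the state dict
    return state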
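Because respond() is a plain Python generator, it can be smoke-tested outside Gradio before the Space rebuilds. The snippet below is an illustrative run from a local Python shell; the prompt and sampling values are arbitrary examples. Importing app loads the model but does not launch the UI, since the launch call is guarded by __main__.

# Illustrative smoke test of the streaming generator (arbitrary example values).
from app import respond

final_text = ""
for partial in respond("Hello! What can you do?", [], "You are a friendly Chatbot.", 128, 0.7, 0.95):
    final_text = partial  # each yield is the full text generated so far
print(final_text)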