staeiou committed on
Commit
28ddeb8
·
verified ·
1 Parent(s): e906781

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +314 -0
app.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # Hugging Face Spaces (Gradio) app that:
3
+ # 1) Loads a Transformers CausalLM from a Hub repo (prefers .safetensors)
4
+ # 2) Runs a fixed list of prompts one-by-one (WITHOUT the "Q:" prefix)
5
+ # 3) Saves the Q/A pairs into examples.md in the requested format
6
+ #
7
+ # Configure via Space Variables/Secrets (recommended):
8
+ # - MODEL_REPO_ID: e.g. "username/my-model-repo"
9
+ # - REVISION: optional (branch/tag/commit)
10
+ # - HF_TOKEN: optional if repo is private
11
+ # - MAX_NEW_TOKENS: optional (default 128)
12
+ #
13
+ # Notes:
14
+ # - This expects the repo to be Transformers-compatible (config/tokenizer present).
15
+ # - If your repo has multiple weight shards, Transformers will pick them up automatically.
16
+ # - The generated examples.md is written to the Space's local filesystem and offered for download.
17
+
18
+ import os
19
+ import time
20
+ from dataclasses import dataclass
21
+ from typing import List, Tuple, Optional
22
+
23
+ import torch
24
+ import gradio as gr
25
+ from huggingface_hub import snapshot_download
26
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
27
+
28
+
29
+ # -----------------------------
30
+ # Prompts (sent WITHOUT "Q:")
31
+ # -----------------------------
32
# Fixed evaluation prompts. Each string is sent to the model exactly as
# written here; the "Q:" label only appears later, in the markdown output.
RAW_PROMPTS: List[str] = [
    # Short English requests.
    "What is the capital of France?",
    "Calculate 2+2",
    "chocolate cake recipe",
    "What model are you?",
    # Keyboard-mash input.
    "a;lkj2l1;j2r';13",
    # Spanish prompt followed by its English counterpart.
    "¿Cuántos libros había en la Biblioteca de Alejandría?",
    "How many books were in the library of Alexandria?",
    # Spanish prompt followed by its English counterpart.
    "Te amo, mi amor. ¿Me amas? ¿Soy tu amor?",
    "My love, I love you. Do you love me? Am I your love?",
    # Urdu prompt followed by its English counterpart.
    "اردو بولنے والے کے طور پر کام کریں۔",
    "Act as an Urdu speaker.",
]
45
+
46
+
47
@dataclass
class LoadSettings:
    """Describe which Hub repository to load and how to place the model."""

    # Hub repository id, e.g. "username/my-model-repo".
    repo_id: str
    # Branch, tag, or commit to fetch; None means no explicit revision.
    revision: Optional[str] = None
    # Access token; only needed when the repository is private.
    hf_token: Optional[str] = None
    # Weight dtype; when None, load_model_and_tokenizer picks one itself.
    torch_dtype: Optional[torch.dtype] = None
    # Target device. The default is computed once, when this class body is
    # executed, from whether CUDA is available at that moment.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
54
+
55
+
56
def _env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is unset or unparsable.

    Returns:
        The parsed integer, or *default* when the variable is missing or its
        text is not a valid integer (for example an empty string).
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # A malformed setting should not crash the app at import time; only
        # the parse failure is swallowed, nothing broader.
        return default


# Default generation length; the MAX_NEW_TOKENS variable overrides 128.
MAX_NEW_TOKENS_DEFAULT = _env_int("MAX_NEW_TOKENS", 128)
64
+
65
+
66
+ # -----------------------------
67
+ # Model loading
68
+ # -----------------------------
69
def load_model_and_tokenizer(settings: "LoadSettings"):
    """Download a Hub repo snapshot and load its causal LM and tokenizer.

    Args:
        settings: Repo id, optional revision and token, optional dtype, and
            the device the model is moved to on the non-CUDA path. When
            ``settings.torch_dtype`` is None it is filled in with the dtype
            chosen here.

    Returns:
        A ``(model, tokenizer, local_dir)`` tuple. The model has had
        ``eval()`` called on it; ``local_dir`` is the path returned by
        ``snapshot_download``.

    Raises:
        ValueError: If ``settings.repo_id`` is empty or only whitespace.
    """
    if not (settings.repo_id or "").strip():
        raise ValueError("MODEL_REPO_ID is empty. Set it in Space variables or type it in the UI.")

    # No custom ``local_dir`` is requested, so the snapshot goes to the
    # default Hub location. The deprecated ``local_dir_use_symlinks`` flag is
    # intentionally not passed.
    local_dir = snapshot_download(
        repo_id=settings.repo_id,
        revision=settings.revision,
        token=settings.hf_token,
    )

    use_cuda = torch.cuda.is_available()

    dtype = settings.torch_dtype
    if dtype is None:
        if use_cuda:
            # Compute capability major >= 8 gets bfloat16, older GPUs float16.
            major = torch.cuda.get_device_capability(0)[0]
            dtype = torch.bfloat16 if major >= 8 else torch.float16
        else:
            dtype = torch.float32
        # Written back so existing callers that inspect the settings object
        # afterwards still see the resolved dtype.
        settings.torch_dtype = dtype

    config = AutoConfig.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir, use_fast=True)

    # generate() is later given tokenizer.pad_token_id, so make sure one
    # exists; the EOS token is reused when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Arguments shared by both device paths.
    # NOTE(review): use_safetensors=True looks like it requires .safetensors
    # weights rather than merely preferring them — confirm if .bin-only repos
    # are expected to load.
    load_kwargs = dict(
        config=config,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    if use_cuda:
        # Placement is delegated to device_map="auto" on the CUDA path.
        model = AutoModelForCausalLM.from_pretrained(local_dir, device_map="auto", **load_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(local_dir, **load_kwargs).to(settings.device)

    model.eval()
    return model, tokenizer, local_dir
121
+
122
+
123
+ # -----------------------------
124
+ # Prompt formatting + generation
125
+ # -----------------------------
126
def build_inputs(tokenizer, prompt: str, device):
    """Tokenize *prompt* into an ``input_ids`` tensor placed on *device*.

    When the tokenizer has a non-empty ``chat_template``, the prompt is
    wrapped as a single user message and rendered through
    ``apply_chat_template`` with the generation prompt appended. Otherwise
    the raw text is tokenized directly.

    Args:
        tokenizer: Tokenizer object; must be callable and, when it defines a
            chat template, provide ``apply_chat_template``.
        prompt: The user text.
        device: Device string or ``torch.device``; forwarded to ``Tensor.to``.

    Returns:
        The ``input_ids`` tensor moved to *device*.
    """
    if getattr(tokenizer, "chat_template", None):
        encoded = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # Be tolerant of tokenizers that hand back a mapping with an
        # "input_ids" entry instead of the bare tensor.
        if not isinstance(encoded, torch.Tensor):
            encoded = encoded["input_ids"]
        return encoded.to(device)

    # No chat template: tokenize the plain text.
    return tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
139
+
140
+
141
@torch.inference_mode()
def generate_one(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.0,
) -> str:
    """Generate one answer for *prompt* and return only the new text.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: The user text, sent as-is.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``.
        temperature: Sampling temperature. ``None`` or any value <= 0
            disables sampling.

    Returns:
        The decoded continuation (prompt tokens excluded, special tokens
        skipped), stripped of surrounding whitespace.
    """
    device = next(model.parameters()).device
    input_ids = build_inputs(tokenizer, prompt, device)

    # Sampling is only enabled for a strictly positive temperature.
    do_sample = temperature is not None and temperature > 0

    outputs = model.generate(
        input_ids=input_ids,
        # The batch is a single unpadded sequence, so every position is a
        # real token. Passing the mask explicitly means generate() does not
        # have to infer it, which matters when pad and EOS share an id.
        attention_mask=torch.ones_like(input_ids),
        # UI controls may deliver the value as a float.
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        top_p=0.95 if do_sample else None,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Slice off the prompt so only newly generated tokens are decoded.
    prompt_len = input_ids.shape[-1]
    new_tokens = outputs[0, prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
169
+
170
+
171
def format_examples_md(pairs: List[Tuple[str, str]]) -> str:
    """Render question/answer pairs as the markdown body of examples.md.

    Each pair becomes a two-line ``- Q:`` / ``- A:`` entry with surrounding
    whitespace stripped. Entries are separated by a blank line and the result
    always ends with a single newline.
    """
    entries = (
        f"- Q: {question}\n- A: {answer}".strip()
        for question, answer in pairs
    )
    return "\n\n".join(entries) + "\n"
176
+
177
+
178
+ # -----------------------------
179
+ # Gradio app logic
180
+ # -----------------------------
181
# Process-wide slots for the currently loaded model. do_load() fills them and
# generate_examples() reads them; all three stay None until a load succeeds.
MODEL = None
TOKENIZER = None
MODEL_LOCAL_DIR = None


def do_load(repo_id: str, revision: str, hf_token: str, max_new_tokens: int):
    """Load a model into the module-level slots and return a status summary.

    Args:
        repo_id: Hub repository id; surrounding whitespace is removed.
        revision: Optional revision; blank means no explicit revision.
        hf_token: Optional token; when blank, the HF_TOKEN environment
            variable is used if it is set.
        max_new_tokens: Only echoed back in the status text.

    Returns:
        A newline-joined, markdown-formatted description of what was loaded.
    """
    global MODEL, TOKENIZER, MODEL_LOCAL_DIR

    clean_repo = (repo_id or "").strip()
    clean_revision = (revision or "").strip() or None
    token = (hf_token or "").strip() or os.getenv("HF_TOKEN") or None

    # The globals are only rebound after loading returns, so a failed load
    # leaves any previously loaded model in place.
    MODEL, TOKENIZER, MODEL_LOCAL_DIR = load_model_and_tokenizer(
        LoadSettings(repo_id=clean_repo, revision=clean_revision, hf_token=token)
    )

    model_device = next(MODEL.parameters()).device
    status_lines = (
        f"Loaded repo: `{clean_repo}`",
        f"Revision: `{clean_revision or 'default'}`",
        f"Local snapshot dir: `{MODEL_LOCAL_DIR}`",
        f"Device: `{model_device}`",
        f"Default max_new_tokens: `{max_new_tokens}`",
    )
    return "\n".join(status_lines)
205
+
206
+
207
def generate_examples(max_new_tokens: int, temperature: float):
    """Answer every prompt in RAW_PROMPTS and write the result to examples.md.

    Args:
        max_new_tokens: Generation length limit passed to generate_one.
        temperature: Sampling temperature passed to generate_one.

    Returns:
        ``(markdown_text, absolute_path_of_examples_md)``.

    Raises:
        RuntimeError: If no model/tokenizer has been loaded yet.
    """
    if MODEL is None or TOKENIZER is None:
        raise RuntimeError("Model not loaded. Click 'Load model' first (or set MODEL_REPO_ID and restart).")

    def _answer(prompt: str) -> str:
        # The prompt goes to the model as-is; the "Q:" label is added only
        # when the markdown is rendered.
        raw = generate_one(
            MODEL,
            TOKENIZER,
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        # Collapse line breaks so each answer stays on one markdown line.
        return " ".join(raw.splitlines()).strip()

    md = format_examples_md([(prompt, _answer(prompt)) for prompt in RAW_PROMPTS])

    # The file lands in the current working directory.
    out_path = os.path.abspath("examples.md")
    with open(out_path, "w", encoding="utf-8") as handle:
        handle.write(md)

    return md, out_path
232
+
233
+
234
def maybe_autoload():
    """Load the model named by MODEL_REPO_ID, if any, and report the outcome.

    Returns:
        A hint to enter a repo id when MODEL_REPO_ID is unset or blank, the
        status text from do_load on success, or an "Autoload failed" message
        naming the exception when loading raises.
    """
    read_env = os.getenv

    repo_id = (read_env("MODEL_REPO_ID") or "").strip()
    if repo_id == "":
        return "MODEL_REPO_ID not set. Enter a repo id and click 'Load model'."

    # Blank values are passed as empty strings; do_load normalizes them.
    revision = (read_env("REVISION") or "").strip()
    hf_token = (read_env("HF_TOKEN") or "").strip()
    max_new_tokens = _env_int("MAX_NEW_TOKENS", MAX_NEW_TOKENS_DEFAULT)

    try:
        status = do_load(repo_id, revision, hf_token, max_new_tokens)
    except Exception as e:
        # Any load failure is turned into text so the caller still gets a
        # displayable status instead of an exception.
        status = f"Autoload failed: {type(e).__name__}: {e}"
    return status
248
+
249
+
250
# The UI is assembled at module level so that `demo` exists as soon as the
# module is imported.
with gr.Blocks(title="Safetensors QA -> examples.md") as demo:
    gr.Markdown(
        """
# Safetensors QA → `examples.md`

This Space loads a Transformers model (preferring `.safetensors`) from a Hub repo and generates answers for a fixed list of prompts (sent **without** the `Q:` prefix).
Then it writes the results into `examples.md` in the requested `- Q:` / `- A:` format.
"""
    )

    with gr.Accordion("Model settings", open=True):
        repo_id_in = gr.Textbox(
            label="MODEL_REPO_ID (Hub repo)",
            value=os.getenv("MODEL_REPO_ID", ""),
            placeholder='e.g. "username/my-model-repo"',
        )
        revision_in = gr.Textbox(
            label="Revision (optional)",
            value=os.getenv("REVISION", ""),
            placeholder="branch / tag / commit (leave empty for default)",
        )
        token_in = gr.Textbox(
            label="HF_TOKEN (optional, for private repos)",
            value="",
            placeholder="Leave empty to use Space secret HF_TOKEN",
            type="password",
        )
        load_btn = gr.Button("Load model", variant="primary")
        # maybe_autoload() is called right here, while the UI is being built,
        # so any automatic model load happens before the app is launched.
        load_status = gr.Markdown(value=maybe_autoload())

    with gr.Accordion("Generation settings", open=True):
        max_new_tokens_in = gr.Slider(
            label="max_new_tokens",
            minimum=16,
            maximum=1024,
            # MAX_NEW_TOKENS comes from the environment and may lie outside
            # the slider's range; clamp it so the initial value is valid.
            value=min(1024, max(16, _env_int("MAX_NEW_TOKENS", MAX_NEW_TOKENS_DEFAULT))),
            step=1,
        )
        temperature_in = gr.Slider(
            label="temperature (0 = deterministic)",
            minimum=0.0,
            maximum=2.0,
            value=0.0,
            step=0.05,
        )

    gr.Markdown("## Generate `examples.md`")
    gen_btn = gr.Button("Run prompts and write examples.md", variant="secondary")
    md_preview = gr.Markdown(label="Preview")
    md_file = gr.File(label="Download examples.md")

    # "Load model" refreshes the status text with do_load's summary.
    load_btn.click(
        fn=do_load,
        inputs=[repo_id_in, revision_in, token_in, max_new_tokens_in],
        outputs=[load_status],
    )

    # The run button fills the preview and offers the written file.
    gen_btn.click(
        fn=generate_examples,
        inputs=[max_new_tokens_in, temperature_in],
        outputs=[md_preview, md_file],
    )

if __name__ == "__main__":
    demo.launch()