mervenoyan commited on
Commit
2d3963a
·
1 Parent(s): 8844088

Switch back to Inference Providers (Qwen3.5-9B)

Browse files

Revert the ZeroGPU/transformers local-inference path. The provider has
JSON mode and 8-way parallel digest calls were noticeably faster than
the serial GPU loop. Keeps the model swap to Qwen3.5-9B, temperature=0
(greedy), max_tokens=1500 on the bulletin call, and a per-call user
reminder restating the 3-sins-and-length-budget constraints. Drops
spaces/transformers/accelerate/torch from requirements.

Files changed (3) hide show
  1. analyze.py +75 -152
  2. app.py +13 -3
  3. requirements.txt +0 -4
analyze.py CHANGED
@@ -1,101 +1,36 @@
1
- """Local Qwen3.6-35B-A3B on ZeroGPU: map (per-session digests) + reduce (bulletin)."""
2
 
3
  import datetime as dt
4
  import hashlib
5
  import json
6
- import re
 
7
 
8
- import spaces
9
- import torch
10
- from transformers import AutoModelForCausalLM, AutoTokenizer
11
 
12
  from extract import event_role, event_tool_names
13
 
14
  MODEL = "Qwen/Qwen3.5-9B"
15
 
16
- # Lazily populated inside the GPU worker on first call.
17
- _tokenizer = None
18
- _model = None
19
-
20
-
21
- def _load():
22
- """Load tokenizer + model on the GPU worker. Cached after first call.
23
-
24
- Two ZeroGPU-specific bits:
25
- - We touch CUDA once (`torch.cuda.init()` + a 1-element alloc) so the
26
- caching allocator's NVML query happens in a known-good state before
27
- transformers' loader starts hammering it per-tensor.
28
- - `low_cpu_mem_usage=True` makes the loader use meta-tensor init and
29
- stream shards onto the device, instead of materialising each tensor
30
- on CPU and then `.to("cuda")` (which is what triggered the NVML
31
- assert under the new core_model_loading path).
32
- """
33
- global _tokenizer, _model
34
- if _model is None:
35
- torch.cuda.init()
36
- _ = torch.empty(1, device="cuda")
37
- torch.cuda.synchronize()
38
-
39
- _tokenizer = AutoTokenizer.from_pretrained(MODEL)
40
- _model = AutoModelForCausalLM.from_pretrained(
41
- MODEL,
42
- torch_dtype=torch.bfloat16,
43
- device_map="cuda",
44
- low_cpu_mem_usage=True,
45
- )
46
- _model.eval()
47
- return _tokenizer, _model
48
-
49
-
50
- def _chat(
51
- tokenizer,
52
- model,
53
- messages: list[dict],
54
- *,
55
- max_new_tokens: int,
56
- temperature: float,
57
- ) -> str:
58
- text = tokenizer.apply_chat_template(
59
- messages,
60
- tokenize=False,
61
- add_generation_prompt=True,
62
- enable_thinking=False,
63
- )
64
- inputs = tokenizer(text, return_tensors="pt").to(model.device)
65
- with torch.inference_mode():
66
- out = model.generate(
67
- **inputs,
68
- max_new_tokens=max_new_tokens,
69
- temperature=temperature,
70
- do_sample=temperature > 0,
71
- pad_token_id=tokenizer.eos_token_id,
72
- )
73
- completion_ids = out[0][inputs.input_ids.shape[1]:]
74
- return tokenizer.decode(completion_ids, skip_special_tokens=True)
75
 
76
 
77
- _FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$", re.IGNORECASE | re.MULTILINE)
78
-
79
-
80
- def _parse_json(text: str) -> dict:
81
- """Forgiving JSON parse: strip markdown fences, find the outermost {...} if needed."""
82
- text = _FENCE_RE.sub("", text.strip()).strip()
83
- try:
84
- return json.loads(text)
85
- except json.JSONDecodeError:
86
- # Fall back to the first balanced { ... } block.
87
- start = text.find("{")
88
- end = text.rfind("}")
89
- if start != -1 and end != -1 and end > start:
90
- return json.loads(text[start : end + 1])
91
- raise
92
 
93
 
94
  # ---------- map: per-session digest ----------
95
 
96
  _DIGEST_SYSTEM = """You are analysing a single coding-agent session transcript. The TRANSCRIPT shows messages between a HUMAN USER and an AGENT (the AI). Return signals about the HUMAN USER only — never about the agent.
97
 
98
- Return STRICT JSON, no prose, no markdown fences:
99
  {
100
  "session_id": <echo>,
101
  "intent": "<one sentence: what the user was trying to do>",
@@ -107,43 +42,43 @@ Return STRICT JSON, no prose, no markdown fences:
107
  Hard rules:
108
  - Only include things the user actually said or did. Do not attribute agent behaviour to the user.
109
  - top_quotes must literally appear in user messages.
110
- - Be concise and specific. No invented quotes.
111
- - Emit JSON only. No commentary."""
112
 
113
 
114
- def _digest_one(tokenizer, model, transcript: str, session_id: str) -> dict | None:
115
  user_prompt = f"session_id: {session_id}\n\nTranscript:\n{transcript}"
116
- messages = [
117
- {"role": "system", "content": _DIGEST_SYSTEM},
118
- {"role": "user", "content": user_prompt},
119
- ]
120
- for attempt in range(2):
121
- try:
122
- raw = _chat(
123
- tokenizer,
124
- model,
125
- messages,
126
- max_new_tokens=800,
127
- temperature=0.4 if attempt == 0 else 0.2,
128
- )
129
- data = _parse_json(raw)
130
- data.setdefault("session_id", session_id)
131
- return data
132
- except Exception:
133
- continue
134
- return None
135
-
136
-
137
- @spaces.GPU(duration=300)
138
- def digest_all(transcripts: list[tuple[str, str]]) -> list[dict]:
139
- """Run a digest for each (session_id, transcript) sequentially on the GPU worker."""
140
- tokenizer, model = _load()
141
- results = []
142
- for sid, text in transcripts:
143
- out = _digest_one(tokenizer, model, text, sid)
144
- if out is not None:
145
- results.append(out)
146
- return results
 
147
 
148
 
149
  # ---------- stats from raw events ----------
@@ -192,6 +127,7 @@ def serial_for(user: str) -> str:
192
 
193
  # ---------- reduce: bulletin generation ----------
194
 
 
195
  _BULLETIN_SYSTEM = """You are the Hugging Face Roastery. You read agent-trace dataset digests and write a gently savage personality bulletin about the HUMAN USER who was prompting the agent — never about the agent itself. The output is a vintage printed card; every field has a strict length budget. Be specific, be funny, never punch down.
196
 
197
  You will receive:
@@ -220,6 +156,8 @@ Field budgets (hard limits — overflow breaks the layout):
220
  - sins[].meta: 30-60 chars
221
  - forecast.body: 270-340 chars, ends with "Lucky <x>: <y>. Avoid: <z>."
222
 
 
 
223
  Voice:
224
  - Sharp but loving — group-chat energy, not insult-comic. Roast habits a thoughtful friend would call out.
225
  - Sentence case for titles. Smart quotes ( " " ), en-dashes ( – ), em-dashes ( — ). No exclamation marks. No emojis.
@@ -242,58 +180,43 @@ Procedure:
242
  7. Emit JSON only. No code fences. No commentary."""
243
 
244
 
245
- def _bulletin_valid(data: dict) -> bool:
246
- """The bulletin must have all 3 sins; budgets are best-effort and not enforced here."""
247
- sins = data.get("sins")
248
- return isinstance(sins, list) and len(sins) >= 3
249
-
250
-
251
- @spaces.GPU(duration=180)
252
- def _bulletin(digests: list[dict], user: str, dataset_id: str) -> dict:
253
- tokenizer, model = _load()
254
  user_prompt = (
255
  f"user: {user}\n"
256
  f"dataset: {dataset_id}\n\n"
257
  f"digests (JSON list):\n{json.dumps(digests, ensure_ascii=False, indent=2)}\n\n"
258
- "Reminder: respect every length budget AND emit EXACTLY 3 sins. "
259
- "Tagline must be ≤170 chars; forecast.body must be ≤340 chars. "
260
- "Output only the JSON object."
261
  )
262
- messages = [
263
- {"role": "system", "content": _BULLETIN_SYSTEM},
264
- {"role": "user", "content": user_prompt},
265
- ]
266
- last_err = None
267
- last_data = None
268
- for attempt in range(3):
269
- raw = _chat(
270
- tokenizer,
271
- model,
272
- messages,
273
- max_new_tokens=1500,
274
- temperature=0.85 if attempt == 0 else 0.4,
275
- )
276
- try:
277
- data = _parse_json(raw)
278
- except Exception as e:
279
- last_err = e
280
- continue
281
- last_data = data
282
- if _bulletin_valid(data):
283
- return data
284
- if last_data is not None:
285
- return last_data
286
- raise RuntimeError(f"Bulletin JSON parse failed: {last_err}")
287
 
288
 
289
  def build_report(
 
290
  digests: list[dict],
291
  user: str,
292
  dataset_id: str,
293
  stats: dict,
294
  ) -> dict:
295
  """Combine model output + computed stats into the full report dict for render.py."""
296
- data = _bulletin(digests, user, dataset_id)
297
  today = dt.date.today().strftime("%b %d, %Y")
298
  archetype = data.get("archetype") or ["The", "Unreadable"]
299
  if not isinstance(archetype, list) or len(archetype) < 2:
 
1
+ """InferenceClient calls: map (per-session digests) + reduce (bulletin)."""
2
 
3
  import datetime as dt
4
  import hashlib
5
  import json
6
+ import os
7
+ from concurrent.futures import ThreadPoolExecutor
8
 
9
+ from huggingface_hub import InferenceClient
 
 
10
 
11
  from extract import event_role, event_tool_names
12
 
13
  MODEL = "Qwen/Qwen3.5-9B"
14
 
15
+ _NO_THINK = {"chat_template_kwargs": {"enable_thinking": False}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
+ def get_client(token: str | None = None) -> InferenceClient:
19
+ """Build the InferenceClient. Centralised so OAuth swap is one place."""
20
+ if token is None:
21
+ token = os.environ.get("HF_TOKEN")
22
+ if not token:
23
+ raise RuntimeError(
24
+ "HF_TOKEN is not set. Export it in your shell or pass token= explicitly."
25
+ )
26
+ return InferenceClient(model=MODEL, token=token)
 
 
 
 
 
 
27
 
28
 
29
  # ---------- map: per-session digest ----------
30
 
31
  _DIGEST_SYSTEM = """You are analysing a single coding-agent session transcript. The TRANSCRIPT shows messages between a HUMAN USER and an AGENT (the AI). Return signals about the HUMAN USER only — never about the agent.
32
 
33
+ Return STRICT JSON:
34
  {
35
  "session_id": <echo>,
36
  "intent": "<one sentence: what the user was trying to do>",
 
42
  Hard rules:
43
  - Only include things the user actually said or did. Do not attribute agent behaviour to the user.
44
  - top_quotes must literally appear in user messages.
45
+ - Be concise and specific. No invented quotes."""
 
46
 
47
 
48
+ def digest_session(client: InferenceClient, transcript: str, session_id: str) -> dict:
49
  user_prompt = f"session_id: {session_id}\n\nTranscript:\n{transcript}"
50
+ try:
51
+ resp = client.chat_completion(
52
+ messages=[
53
+ {"role": "system", "content": _DIGEST_SYSTEM},
54
+ {"role": "user", "content": user_prompt},
55
+ ],
56
+ response_format={"type": "json_object"},
57
+ max_tokens=800,
58
+ temperature=0,
59
+ extra_body=_NO_THINK,
60
+ )
61
+ raw = resp.choices[0].message.content or "{}"
62
+ data = json.loads(raw)
63
+ data.setdefault("session_id", session_id)
64
+ return data
65
+ except Exception as e:
66
+ return {"session_id": session_id, "error": str(e)}
67
+
68
+
69
+ def digest_all(
70
+ client: InferenceClient,
71
+ transcripts: list[tuple[str, str]],
72
+ max_workers: int = 8,
73
+ ) -> list[dict]:
74
+ """Run digest_session over all transcripts in parallel. Drops error entries."""
75
+ def _one(item):
76
+ sid, text = item
77
+ return digest_session(client, text, sid)
78
+
79
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
80
+ results = list(ex.map(_one, transcripts))
81
+ return [r for r in results if "error" not in r]
82
 
83
 
84
  # ---------- stats from raw events ----------
 
127
 
128
  # ---------- reduce: bulletin generation ----------
129
 
130
+ # Adapted from the design handoff's CONTENT_PROMPT.md.
131
  _BULLETIN_SYSTEM = """You are the Hugging Face Roastery. You read agent-trace dataset digests and write a gently savage personality bulletin about the HUMAN USER who was prompting the agent — never about the agent itself. The output is a vintage printed card; every field has a strict length budget. Be specific, be funny, never punch down.
132
 
133
  You will receive:
 
156
  - sins[].meta: 30-60 chars
157
  - forecast.body: 270-340 chars, ends with "Lucky <x>: <y>. Avoid: <z>."
158
 
159
+ The sins array MUST contain exactly 3 objects. Do not emit fewer.
160
+
161
  Voice:
162
  - Sharp but loving — group-chat energy, not insult-comic. Roast habits a thoughtful friend would call out.
163
  - Sentence case for titles. Smart quotes ( " " ), en-dashes ( – ), em-dashes ( — ). No exclamation marks. No emojis.
 
180
  7. Emit JSON only. No code fences. No commentary."""
181
 
182
 
183
+ def bulletin(
184
+ client: InferenceClient,
185
+ digests: list[dict],
186
+ user: str,
187
+ dataset_id: str,
188
+ ) -> dict:
189
+ """Generate the report content (archetype, tagline, sins, forecast). One JSON call."""
 
 
190
  user_prompt = (
191
  f"user: {user}\n"
192
  f"dataset: {dataset_id}\n\n"
193
  f"digests (JSON list):\n{json.dumps(digests, ensure_ascii=False, indent=2)}\n\n"
194
+ "Reminder: emit EXACTLY 3 sins and respect every length budget. "
195
+ "Tagline ≤170 chars; forecast.body ≤340 chars."
 
196
  )
197
+ resp = client.chat_completion(
198
+ messages=[
199
+ {"role": "system", "content": _BULLETIN_SYSTEM},
200
+ {"role": "user", "content": user_prompt},
201
+ ],
202
+ response_format={"type": "json_object"},
203
+ max_tokens=1500,
204
+ temperature=0,
205
+ extra_body=_NO_THINK,
206
+ )
207
+ raw = resp.choices[0].message.content or "{}"
208
+ return json.loads(raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
 
211
  def build_report(
212
+ client: InferenceClient,
213
  digests: list[dict],
214
  user: str,
215
  dataset_id: str,
216
  stats: dict,
217
  ) -> dict:
218
  """Combine model output + computed stats into the full report dict for render.py."""
219
+ data = bulletin(client, digests, user, dataset_id)
220
  today = dt.date.today().strftime("%b %d, %Y")
221
  archetype = data.get("archetype") or ["The", "Unreadable"]
222
  if not isinstance(archetype, list) or len(archetype) < 2:
app.py CHANGED
@@ -5,12 +5,13 @@ endpoint below via `@gradio/client`. Report generation logic is unchanged
5
  from the original Blocks app.
6
  """
7
 
 
8
  from pathlib import Path
9
 
10
  from fastapi.responses import HTMLResponse
11
  from gradio import Server
12
 
13
- from analyze import build_report, compute_stats, digest_all
14
  from dataset import fetch_sessions, list_sessions
15
  from extract import events_to_transcript, truncate_transcript
16
  from render import bulletin_html, empty_bulletin_html
@@ -34,6 +35,12 @@ def generate_bulletin(
34
 
35
  yield "Connecting…", empty_bulletin_html("Connecting…")
36
 
 
 
 
 
 
 
37
  try:
38
  yield "Listing sessions…", empty_bulletin_html("Listing sessions…")
39
  paths = list_sessions(repo_id)
@@ -64,10 +71,10 @@ def generate_bulletin(
64
  ]
65
 
66
  yield (
67
- f"Reading {len(transcripts)} sessions on GPU…",
68
  empty_bulletin_html("Consulting the traces…"),
69
  )
70
- digests = digest_all(transcripts)
71
  if not digests:
72
  yield (
73
  "Every per-session digest failed. Try again or lower max sessions.",
@@ -83,6 +90,7 @@ def generate_bulletin(
83
  owner = _owner_from(repo_id)
84
  try:
85
  report = build_report(
 
86
  digests=digests,
87
  user=owner,
88
  dataset_id=repo_id,
@@ -112,4 +120,6 @@ async def homepage():
112
 
113
 
114
  if __name__ == "__main__":
 
 
115
  app.launch(show_error=True)
 
5
  from the original Blocks app.
6
  """
7
 
8
+ import os
9
  from pathlib import Path
10
 
11
  from fastapi.responses import HTMLResponse
12
  from gradio import Server
13
 
14
+ from analyze import build_report, compute_stats, digest_all, get_client
15
  from dataset import fetch_sessions, list_sessions
16
  from extract import events_to_transcript, truncate_transcript
17
  from render import bulletin_html, empty_bulletin_html
 
35
 
36
  yield "Connecting…", empty_bulletin_html("Connecting…")
37
 
38
+ try:
39
+ client = get_client()
40
+ except Exception as e:
41
+ yield f"❌ {e}", empty_bulletin_html("HF_TOKEN missing")
42
+ return
43
+
44
  try:
45
  yield "Listing sessions…", empty_bulletin_html("Listing sessions…")
46
  paths = list_sessions(repo_id)
 
71
  ]
72
 
73
  yield (
74
+ f"Reading {len(transcripts)} sessions in parallel…",
75
  empty_bulletin_html("Consulting the traces…"),
76
  )
77
+ digests = digest_all(client, transcripts)
78
  if not digests:
79
  yield (
80
  "Every per-session digest failed. Try again or lower max sessions.",
 
90
  owner = _owner_from(repo_id)
91
  try:
92
  report = build_report(
93
+ client=client,
94
  digests=digests,
95
  user=owner,
96
  dataset_id=repo_id,
 
120
 
121
 
122
  if __name__ == "__main__":
123
+ if not os.environ.get("HF_TOKEN"):
124
+ print("warning: HF_TOKEN not set; the app will error on the first click.")
125
  app.launch(show_error=True)
requirements.txt CHANGED
@@ -1,7 +1,3 @@
1
  gradio>=6.14
2
  huggingface_hub>=0.28
3
  Pillow>=10.0
4
- spaces
5
- transformers>=4.45
6
- accelerate>=0.30
7
- torch
 
1
  gradio>=6.14
2
  huggingface_hub>=0.28
3
  Pillow>=10.0