riazmo commited on
Commit
4ab81ad
Β·
verified Β·
1 Parent(s): 873b292

Upload stage2_graph.py

Browse files
Files changed (1) hide show
  1. agents/stage2_graph.py +837 -0
agents/stage2_graph.py ADDED
@@ -0,0 +1,837 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stage 2 Multi-Agent Analysis Workflow (LangGraph)
3
+
4
+ Architecture:
5
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
6
+ β”‚ LLM 1 β”‚ β”‚ LLM 2 β”‚ β”‚ Rule Engine β”‚
7
+ β”‚ (Qwen) β”‚ β”‚ (Llama) β”‚ β”‚ (No LLM) β”‚
8
+ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜
9
+ β”‚ β”‚ β”‚
10
+ β”‚ PARALLEL β”‚ β”‚
11
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
12
+ β”‚
13
+ β–Ό
14
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
15
+ β”‚ HEAD β”‚
16
+ β”‚ (Compiler) β”‚
17
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
18
+ """
19
+
20
+ import asyncio
21
+ import json
22
+ import os
23
+ import time
24
+ import yaml
25
+ from dataclasses import dataclass, field
26
+ from datetime import datetime
27
+ from typing import Any, Callable, Optional
28
+
29
+ from langgraph.graph import END, START, StateGraph
30
+ from typing_extensions import TypedDict
31
+
32
+ # =============================================================================
33
+ # CONFIGURATION LOADING
34
+ # =============================================================================
35
+
36
def load_agent_config() -> dict:
    """Load agent configuration from ``config/agents.yaml``.

    Returns:
        The parsed YAML mapping, or an empty dict when the config file does
        not exist so callers can fall back to their hard-coded defaults.
    """
    cfg_file = os.path.join(os.path.dirname(__file__), "..", "config", "agents.yaml")
    if not os.path.exists(cfg_file):
        return {}
    with open(cfg_file, 'r') as fh:
        return yaml.safe_load(fh)
43
+
44
+
45
+ # =============================================================================
46
+ # STATE DEFINITION
47
+ # =============================================================================
48
+
49
class Stage2State(TypedDict):
    """State for Stage 2 multi-agent analysis.

    Shared LangGraph state: the three parallel nodes each return a partial
    update (their own analysis plus a completion timestamp), which LangGraph
    merges into this state before the HEAD compiler runs.
    """

    # Inputs
    desktop_tokens: dict      # design tokens extracted from the desktop viewport
    mobile_tokens: dict       # design tokens extracted from the mobile viewport
    competitors: list[str]    # competitor design systems referenced in prompts

    # Parallel analysis outputs (None until the corresponding node has run)
    llm1_analysis: Optional[dict]
    llm2_analysis: Optional[dict]
    rule_calculations: Optional[dict]

    # HEAD output (synthesized recommendations, or rule-based fallback)
    final_recommendations: Optional[dict]

    # Metadata
    analysis_log: list[str]   # human-readable progress lines
    cost_tracking: dict       # snapshot of the global CostTracker
    errors: list[str]         # accumulated per-node error messages

    # Timing — epoch-second completion timestamps for each phase
    start_time: float
    llm1_time: float
    llm2_time: float
    head_time: float
75
+
76
+
77
+ # =============================================================================
78
+ # COST TRACKING
79
+ # =============================================================================
80
+
81
@dataclass
class CostTracker:
    """Accumulate token counts and estimated USD cost across LLM calls."""

    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    calls: list = field(default_factory=list)

    def add_call(self, agent_name: str, model: str, input_tokens: int, output_tokens: int,
                 cost_per_m_input: float, cost_per_m_output: float, duration: float):
        """Record one LLM call and fold its estimated cost into the totals."""
        # Per-million-token pricing, applied separately to input and output.
        call_cost = ((input_tokens / 1_000_000) * cost_per_m_input
                     + (output_tokens / 1_000_000) * cost_per_m_output)

        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.total_cost += call_cost

        self.calls.append({
            "agent": agent_name,
            "model": model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": call_cost,
            "duration": duration,
        })

    def to_dict(self) -> dict:
        """Serialize totals plus the per-call log for reporting."""
        return {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_cost": round(self.total_cost, 6),
            "calls": self.calls,
        }
117
+
118
+
119
# Global cost tracker shared by all nodes within one workflow run.
# Reset by run_stage2_multi_agent() at the start of each invocation;
# NOTE(review): module-global state, not safe for concurrent runs.
cost_tracker = CostTracker()
121
+
122
+
123
+ # =============================================================================
124
+ # LLM CLIENT
125
+ # =============================================================================
126
+
127
async def call_llm(
    agent_name: str,
    model: str,
    provider: str,
    prompt: str,
    max_tokens: int = 1500,
    temperature: float = 0.4,
    cost_per_m_input: float = 0.5,
    cost_per_m_output: float = 0.5,
    log_callback: Optional[Callable] = None,
) -> tuple[str, int, int]:
    """Call an LLM via HuggingFace Inference Providers.

    Attempts, in order:
      1. provider set at InferenceClient construction time,
      2. the ``model:provider`` suffix syntax (older hub versions where the
         client has no ``provider`` kwarg),
      3. the plain model name with the default provider.

    Args:
        agent_name: Label used in logs and in the cost tracker.
        model: HF model id.
        provider: Inference provider name (e.g. "novita").
        prompt: Single-turn user prompt.
        max_tokens / temperature: Generation parameters.
        cost_per_m_input / cost_per_m_output: USD per million tokens.
        log_callback: Optional line-logger; no logging when None.

    Returns:
        (content, input_tokens, output_tokens). Token counts are rough
        word-based estimates (~1.3 tokens/word), not provider-reported usage.

    Raises:
        Exception: any unrecoverable failure (missing HF_TOKEN, network
        error, all fallbacks exhausted) after logging it.
    """
    start_time = time.time()

    def _finish(content: str) -> tuple[str, int, int]:
        # Shared bookkeeping for every successful path (was copy-pasted
        # three times): estimate tokens, record cost, log a summary.
        input_tokens = int(len(prompt.split()) * 1.3)   # rough estimate
        output_tokens = int(len(content.split()) * 1.3)
        duration = time.time() - start_time

        cost_tracker.add_call(
            agent_name=agent_name,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_per_m_input=cost_per_m_input,
            cost_per_m_output=cost_per_m_output,
            duration=duration,
        )

        if log_callback:
            est_cost = ((input_tokens / 1_000_000) * cost_per_m_input +
                        (output_tokens / 1_000_000) * cost_per_m_output)
            log_callback(f" βœ… {agent_name}: Complete ({duration:.1f}s, ~{int(input_tokens)} in, ~{int(output_tokens)} out)")
            log_callback(f" πŸ’΅ Est. cost: ${est_cost:.4f}")

        return content, input_tokens, output_tokens

    def _chat(client, model_name: str) -> str:
        # NOTE(review): chat_completion is a blocking call inside an async
        # function; callers offload via asyncio where needed — confirm.
        response = client.chat_completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content

    if log_callback:
        log_callback(f" πŸš€ {agent_name}: Calling {model} via {provider}...")

    try:
        from huggingface_hub import InferenceClient

        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not set")

        try:
            # Primary path: provider is set at client level, not per-call.
            client = InferenceClient(
                token=hf_token,
                provider=provider,
            )
            return _finish(_chat(client, model))
        except TypeError as e:
            # Only fall back when the failure is the missing `provider`
            # kwarg on older huggingface_hub versions.
            if "provider" not in str(e):
                raise
            if log_callback:
                log_callback(f" ⚠️ {agent_name}: Trying model:provider format...")
            client = InferenceClient(token=hf_token)
            try:
                # Fallback 1: append provider to the model name.
                return _finish(_chat(client, f"{model}:{provider}"))
            except Exception:
                # Fallback 2 (final): plain model name, default provider.
                if log_callback:
                    log_callback(f" ⚠️ {agent_name}: Trying without provider...")
                return _finish(_chat(client, model))
    except Exception as e:
        # Single error path: previously a TypeError re-raised from inside
        # the fallback handler escaped without this log line.
        duration = time.time() - start_time
        if log_callback:
            log_callback(f" ❌ {agent_name}: Error after {duration:.1f}s - {str(e)}")
        raise
282
+
283
+
284
+ # =============================================================================
285
+ # ANALYSIS NODES
286
+ # =============================================================================
287
+
288
async def analyze_with_llm1(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """LLM 1 (Qwen) analysis node.

    Builds the analyst prompt from the extracted tokens, calls the configured
    model, and returns a partial state update: the parsed analysis (or an
    error payload) plus a completion timestamp.
    """
    agent_cfg = load_agent_config().get("stage2_llm1", {})
    model = agent_cfg.get("model", "Qwen/Qwen2.5-72B-Instruct")
    provider = agent_cfg.get("provider", "novita")

    if log_callback:
        log_callback("")
        log_callback(f"πŸ€– LLM 1: {model}")
        log_callback(f" Provider: {provider}")
        log_callback(f" πŸ’° Cost: ${agent_cfg.get('cost_per_million_input', 0.29)}/M in, ${agent_cfg.get('cost_per_million_output', 0.59)}/M out")
        log_callback(f" πŸ“ Task: Typography, Colors, AA, Spacing analysis")

    # Shared analyst prompt, personalized via the configured persona.
    prompt = build_analyst_prompt(
        tokens_summary=summarize_tokens(state["desktop_tokens"], state["mobile_tokens"]),
        competitors=state["competitors"],
        persona=agent_cfg.get("persona", "Senior Design Systems Architect"),
    )

    try:
        content, in_tokens, out_tokens = await call_llm(
            agent_name="LLM 1 (Qwen)",
            model=model,
            provider=provider,
            prompt=prompt,
            max_tokens=agent_cfg.get("max_tokens", 1500),
            temperature=agent_cfg.get("temperature", 0.4),
            cost_per_m_input=agent_cfg.get("cost_per_million_input", 0.29),
            cost_per_m_output=agent_cfg.get("cost_per_million_output", 0.59),
            log_callback=log_callback,
        )
    except Exception as exc:
        # Degrade gracefully: record the failure, let the HEAD node proceed.
        return {
            "llm1_analysis": {"error": str(exc)},
            "errors": state.get("errors", []) + [f"LLM1: {str(exc)}"],
            "llm1_time": time.time(),
        }

    analysis = parse_llm_response(content)
    analysis["_meta"] = {
        "model": model,
        "provider": provider,
        "input_tokens": in_tokens,
        "output_tokens": out_tokens,
    }
    return {"llm1_analysis": analysis, "llm1_time": time.time()}
341
+
342
+
343
async def analyze_with_llm2(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """LLM 2 (Llama) analysis node.

    Mirrors analyze_with_llm1 but with the second analyst's configuration;
    returns a partial state update with the parsed analysis (or an error
    payload) and a completion timestamp.
    """
    agent_cfg = load_agent_config().get("stage2_llm2", {})
    model = agent_cfg.get("model", "meta-llama/Llama-3.3-70B-Instruct")
    provider = agent_cfg.get("provider", "novita")

    if log_callback:
        log_callback("")
        log_callback(f"πŸ€– LLM 2: {model}")
        log_callback(f" Provider: {provider}")
        log_callback(f" πŸ’° Cost: ${agent_cfg.get('cost_per_million_input', 0.59)}/M in, ${agent_cfg.get('cost_per_million_output', 0.79)}/M out")
        log_callback(f" πŸ“ Task: Typography, Colors, AA, Spacing analysis")

    # Same prompt template as analyst 1 — diversity comes from the model.
    prompt = build_analyst_prompt(
        tokens_summary=summarize_tokens(state["desktop_tokens"], state["mobile_tokens"]),
        competitors=state["competitors"],
        persona=agent_cfg.get("persona", "Senior Design Systems Architect"),
    )

    try:
        content, in_tokens, out_tokens = await call_llm(
            agent_name="LLM 2 (Llama)",
            model=model,
            provider=provider,
            prompt=prompt,
            max_tokens=agent_cfg.get("max_tokens", 1500),
            temperature=agent_cfg.get("temperature", 0.4),
            cost_per_m_input=agent_cfg.get("cost_per_million_input", 0.59),
            cost_per_m_output=agent_cfg.get("cost_per_million_output", 0.79),
            log_callback=log_callback,
        )
    except Exception as exc:
        # Degrade gracefully: record the failure, let the HEAD node proceed.
        return {
            "llm2_analysis": {"error": str(exc)},
            "errors": state.get("errors", []) + [f"LLM2: {str(exc)}"],
            "llm2_time": time.time(),
        }

    analysis = parse_llm_response(content)
    analysis["_meta"] = {
        "model": model,
        "provider": provider,
        "input_tokens": in_tokens,
        "output_tokens": out_tokens,
    }
    return {"llm2_analysis": analysis, "llm2_time": time.time()}
396
+
397
+
398
def run_rule_engine(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """Rule engine node (no LLM, always runs).

    Deterministically derives candidate design-system options from the
    extracted tokens: three type scales (ratios 1.2 / 1.25 / 1.333), two
    spacing grids (4px / 8px), and color ramps for up to 8 base colors.

    Returns:
        A partial state update with the "rule_calculations" payload.
    """

    if log_callback:
        log_callback("")
        log_callback("βš™οΈ Rule Engine: Running calculations...")
        log_callback(" πŸ’° Cost: FREE (no LLM)")

    start = time.time()

    # Calculate type scale options from the detected base size.
    base_size = detect_base_font_size(state["desktop_tokens"])
    type_scales = {
        "1.2": generate_type_scale(base_size, 1.2),
        "1.25": generate_type_scale(base_size, 1.25),
        "1.333": generate_type_scale(base_size, 1.333),
    }

    # Calculate spacing options.
    spacing_options = {
        "4px": generate_spacing_scale(4),
        "8px": generate_spacing_scale(8),
    }

    # Generate color ramps for each base color (project-local helper,
    # imported lazily inside the node).
    from core.color_utils import generate_color_ramp

    color_ramps = {}
    colors = state["desktop_tokens"].get("colors", {})
    for name, color in list(colors.items())[:8]:
        hex_val = color.get("value") if isinstance(color, dict) else str(color)
        try:
            color_ramps[name] = generate_color_ramp(hex_val)
        except Exception:
            # Best-effort: skip colors the ramp generator cannot parse.
            # Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit.
            continue

    duration = time.time() - start

    if log_callback:
        log_callback(f" βœ… Rule Engine: Complete ({duration:.2f}s)")
        log_callback(f" Generated: {len(type_scales)} type scales, {len(spacing_options)} spacing grids, {len(color_ramps)} color ramps")

    return {
        "rule_calculations": {
            "base_font_size": base_size,
            "type_scales": type_scales,
            "spacing_options": spacing_options,
            "color_ramps": color_ramps,
        }
    }
448
+
449
+
450
async def compile_with_head(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """HEAD compiler node.

    Synthesizes both analyst reports and the rule-engine output into final
    recommendations with one more LLM call. On any failure it returns
    rule-of-thumb fallback recommendations instead of raising, so the
    workflow always ends with a usable result.
    """

    config = load_agent_config()
    head_config = config.get("stage2_head", {})

    # Model/provider from config, with hard-coded safety defaults.
    model = head_config.get("model", "meta-llama/Llama-3.3-70B-Instruct")
    provider = head_config.get("provider", "novita")

    if log_callback:
        log_callback("")
        log_callback("=" * 50)
        log_callback("🧠 HEAD COMPILER: Synthesizing results...")
        log_callback(f" Model: {model}")
        log_callback(f" Provider: {provider}")
        log_callback(f" πŸ’° Cost: ${head_config.get('cost_per_million_input', 0.59)}/M in, ${head_config.get('cost_per_million_output', 0.79)}/M out")

    # Build HEAD prompt from whatever the parallel nodes produced; any of
    # the three inputs may be missing or an error dict, hence .get defaults.
    prompt = build_head_prompt(
        llm1_analysis=state.get("llm1_analysis", {}),
        llm2_analysis=state.get("llm2_analysis", {}),
        rule_calculations=state.get("rule_calculations", {}),
    )

    try:
        response, in_tokens, out_tokens = await call_llm(
            agent_name="HEAD",
            model=model,
            provider=provider,
            prompt=prompt,
            max_tokens=head_config.get("max_tokens", 2000),
            temperature=head_config.get("temperature", 0.3),
            cost_per_m_input=head_config.get("cost_per_million_input", 0.59),
            cost_per_m_output=head_config.get("cost_per_million_output", 0.79),
            log_callback=log_callback,
        )

        # Parse response (parse_llm_response never raises; it returns a
        # {"parse_error": True} payload on bad JSON).
        recommendations = parse_llm_response(response)
        recommendations["_meta"] = {
            "model": model,
            "provider": provider,
            "input_tokens": in_tokens,
            "output_tokens": out_tokens,
        }

        # Add cost summary covering every call recorded during this run.
        recommendations["cost_summary"] = cost_tracker.to_dict()

        if log_callback:
            log_callback("")
            log_callback("=" * 50)
            log_callback(f"πŸ’° TOTAL ESTIMATED COST: ${cost_tracker.total_cost:.4f}")
            log_callback(f" (Free tier: $0.10/mo | Pro: $2/mo)")
            log_callback("=" * 50)

        return {
            "final_recommendations": recommendations,
            "cost_tracking": cost_tracker.to_dict(),
            "head_time": time.time(),
        }

    except Exception as e:
        if log_callback:
            log_callback(f" ❌ HEAD Error: {str(e)}")

        # Fallback to rule-based recommendations so the workflow still
        # produces a complete result; the error is surfaced via "errors".
        return {
            "final_recommendations": build_fallback_recommendations(state),
            "errors": state.get("errors", []) + [f"HEAD: {str(e)}"],
            "head_time": time.time(),
        }
522
+
523
+
524
+ # =============================================================================
525
+ # HELPER FUNCTIONS
526
+ # =============================================================================
527
+
528
def summarize_tokens(desktop: dict, mobile: dict) -> str:
    """Build a compact, prompt-ready text summary of the extracted tokens.

    Lists up to 5 colors and 5 desktop typography styles, plus counts for
    mobile typography and spacing values.
    """
    out = []

    # Colors (first 5, value pulled from dict tokens or stringified).
    palette = desktop.get("colors", {})
    out.append(f"### Colors ({len(palette)} detected)")
    for name, c in list(palette.items())[:5]:
        val = c.get("value") if isinstance(c, dict) else str(c)
        out.append(f"- {name}: {val}")

    # Desktop typography (first 5 dict-shaped styles only).
    desk_typo = desktop.get("typography", {})
    out.append(f"\n### Typography Desktop ({len(desk_typo)} styles)")
    for name, style in list(desk_typo.items())[:5]:
        if isinstance(style, dict):
            out.append(f"- {name}: {style.get('font_size', '?')} / {style.get('font_weight', '?')}")

    # Mobile typography and spacing: counts only.
    mobile_typo = mobile.get("typography", {})
    out.append(f"\n### Typography Mobile ({len(mobile_typo)} styles)")

    spacing = desktop.get("spacing", {})
    out.append(f"\n### Spacing ({len(spacing)} values)")

    return "\n".join(out)
555
+
556
+
557
def build_analyst_prompt(tokens_summary: str, competitors: list[str], persona: str) -> str:
    """Build the shared prompt for both analyst LLMs.

    Args:
        tokens_summary: Text block produced by summarize_tokens().
        competitors: Competitor design systems to name-drop in the prompt.
        persona: Role the model is asked to adopt (from agent config).

    Returns:
        A prompt that instructs the model to reply with a single JSON
        object (the ```json fenced schema below) — parse_llm_response()
        depends on that format.
    """
    return f"""You are a {persona}.

## YOUR TASK
Analyze these design tokens extracted from a website and compare against industry best practices.

## EXTRACTED TOKENS
{tokens_summary}

## COMPETITOR DESIGN SYSTEMS TO RESEARCH
{', '.join(competitors)}

## ANALYZE THE FOLLOWING:

### 1. Typography
- Is the type scale consistent? Does it follow a mathematical ratio?
- What is the detected base size?
- Compare to competitors: what ratios do they use?
- Score (1-10) and specific recommendations

### 2. Colors
- Is the color palette cohesive?
- Are semantic colors properly defined (primary, secondary, etc.)?
- Score (1-10) and specific recommendations

### 3. Accessibility (AA Compliance)
- What contrast issues might exist?
- Score (1-10)

### 4. Spacing
- Is spacing consistent? Does it follow a grid (4px, 8px)?
- Score (1-10) and specific recommendations

### 5. Overall Assessment
- Top 3 priorities for improvement

## RESPOND IN JSON FORMAT ONLY:
```json
{{
"typography": {{"analysis": "...", "detected_ratio": 1.2, "score": 7, "recommendations": ["..."]}},
"colors": {{"analysis": "...", "score": 6, "recommendations": ["..."]}},
"accessibility": {{"issues": ["..."], "score": 5}},
"spacing": {{"analysis": "...", "detected_base": 8, "score": 7, "recommendations": ["..."]}},
"top_3_priorities": ["...", "...", "..."],
"confidence": 85
}}
```"""
605
+
606
+
607
def build_head_prompt(llm1_analysis: dict, llm2_analysis: dict, rule_calculations: dict) -> str:
    """Build the prompt for the HEAD compiler.

    Args:
        llm1_analysis / llm2_analysis: Parsed analyst outputs (may be error
            payloads; each is serialized and truncated to 2000 chars to
            bound prompt size).
        rule_calculations: Rule-engine output; only base_font_size is
            interpolated, the scale/spacing options are listed statically.

    Returns:
        A prompt instructing the model to reconcile both analyses and reply
        with a single JSON object matching the fenced schema below.
    """
    return f"""You are a Principal Design Systems Architect compiling analyses from two expert analysts.

## ANALYST 1 FINDINGS:
{json.dumps(llm1_analysis, indent=2, default=str)[:2000]}

## ANALYST 2 FINDINGS:
{json.dumps(llm2_analysis, indent=2, default=str)[:2000]}

## RULE-BASED CALCULATIONS:
- Base font size: {rule_calculations.get('base_font_size', 16)}px
- Type scale options: 1.2, 1.25, 1.333
- Spacing options: 4px grid, 8px grid

## YOUR TASK:
1. Compare both analyst perspectives
2. Identify agreements and disagreements
3. Synthesize final recommendations

## RESPOND IN JSON FORMAT ONLY:
```json
{{
"agreements": [{{"topic": "...", "finding": "..."}}],
"disagreements": [{{"topic": "...", "resolution": "..."}}],
"final_recommendations": {{
"type_scale": "1.25",
"type_scale_rationale": "...",
"spacing_base": "8px",
"spacing_rationale": "...",
"color_improvements": ["..."],
"accessibility_fixes": ["..."]
}},
"overall_confidence": 85,
"summary": "..."
}}
```"""
644
+
645
+
646
def parse_llm_response(response: str) -> dict:
    """Parse a JSON object out of an LLM response.

    Handles responses wrapped in ```json / ``` fences (including a missing
    closing fence, which previously made str.find return -1 and silently
    truncate the last character) as well as bare JSON.

    Returns:
        The parsed dict, or {"raw_response": <first 500 chars>,
        "parse_error": True} when no valid JSON can be extracted. Never
        raises.
    """
    try:
        # Extract the fenced body if the model wrapped its answer.
        if "```json" in response:
            start = response.find("```json") + 7
        elif "```" in response:
            start = response.find("```") + 3
        else:
            start = 0

        if start:
            end = response.find("```", start)
            if end == -1:
                # No closing fence — take everything after the opener.
                end = len(response)
            json_str = response[start:end].strip()
        else:
            json_str = response.strip()

        return json.loads(json_str)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError; the old bare `except:`
        # also swallowed KeyboardInterrupt/SystemExit.
        return {"raw_response": response[:500], "parse_error": True}
664
+
665
+
666
def detect_base_font_size(tokens: dict) -> int:
    """Detect the base (body) font size from typography tokens.

    Collects every font_size that falls in the plausible body-text range
    (14-18, after stripping px/rem/em suffixes) and returns the most common
    value; defaults to 16 when nothing usable is found.

    Args:
        tokens: Token dict with an optional "typography" mapping of
            style-name -> {"font_size": "...", ...}.

    Returns:
        The modal candidate size as an int, or 16.
    """
    typography = tokens.get("typography", {})

    sizes: list[float] = []
    for t in typography.values():
        if isinstance(t, dict):
            size_str = str(t.get("font_size", "16px"))
            try:
                # Strip "rem" before "em" so "1.5rem" -> "1.5", not "1.5r".
                size = float(size_str.replace("px", "").replace("rem", "").replace("em", ""))
            except ValueError:
                # Non-numeric size (e.g. "large", "calc(...)"): skip it.
                # Was a bare `except: pass` that hid every error type.
                continue
            if 14 <= size <= 18:
                sizes.append(size)

    if sizes:
        # Mode of the candidates (ties broken arbitrarily by max over a set).
        return int(max(set(sizes), key=sizes.count))
    return 16
684
+
685
+
686
def generate_type_scale(base: int, ratio: float) -> list[int]:
    """Generate a 13-step type scale (largest first) from base and ratio.

    Steps span ratio**8 down to ratio**-4 relative to the base size
    (display.2xl down to overline), each rounded to the nearest even pixel.
    """
    # range(8, -5, -1) yields the 13 exponents 8, 7, ..., 0, ..., -4.
    return [int(round(base * ratio ** exp / 2) * 2) for exp in range(8, -5, -1)]
695
+
696
+
697
def generate_spacing_scale(base: int) -> list[int]:
    """Return a 17-step spacing scale: 0, base, 2*base, ..., 16*base."""
    return [step * base for step in range(17)]
700
+
701
+
702
+ def build_fallback_recommendations(state: Stage2State) -> dict:
703
+ """Build fallback recommendations if HEAD fails."""
704
+ rule_calc = state.get("rule_calculations", {})
705
+
706
+ return {
707
+ "final_recommendations": {
708
+ "type_scale": "1.25",
709
+ "type_scale_rationale": "Major Third (1.25) is industry standard",
710
+ "spacing_base": "8px",
711
+ "spacing_rationale": "8px grid provides good visual rhythm",
712
+ "color_improvements": ["Generate full ramps (50-950)"],
713
+ "accessibility_fixes": ["Review contrast ratios"],
714
+ },
715
+ "overall_confidence": 60,
716
+ "summary": "Recommendations based on rule-based analysis (LLM unavailable)",
717
+ "fallback": True,
718
+ }
719
+
720
+
721
+ # =============================================================================
722
+ # WORKFLOW BUILDER
723
+ # =============================================================================
724
+
725
def build_stage2_workflow():
    """Build the LangGraph workflow for Stage 2.

    Topology: START fans out to the two LLM analysts and the rule engine in
    parallel; all three fan in to the HEAD compiler, which is the single
    terminal node before END.
    """
    graph = StateGraph(Stage2State)

    # Register all nodes.
    node_table = {
        "llm1_analyst": analyze_with_llm1,
        "llm2_analyst": analyze_with_llm2,
        "rule_engine": run_rule_engine,
        "head_compiler": compile_with_head,
    }
    for node_name, node_fn in node_table.items():
        graph.add_node(node_name, node_fn)

    # Fan-out from START and fan-in to the HEAD compiler.
    for parallel_node in ("llm1_analyst", "llm2_analyst", "rule_engine"):
        graph.add_edge(START, parallel_node)
        graph.add_edge(parallel_node, "head_compiler")

    graph.add_edge("head_compiler", END)

    return graph.compile()
750
+
751
+
752
+ # =============================================================================
753
+ # MAIN RUNNER
754
+ # =============================================================================
755
+
756
async def run_stage2_multi_agent(
    desktop_tokens: dict,
    mobile_tokens: dict,
    competitors: list[str],
    log_callback: Optional[Callable] = None,
) -> dict:
    """Run the Stage 2 multi-agent analysis.

    Executes the two LLM analysts and the rule engine concurrently, merges
    their partial state updates, then runs the HEAD compiler. Never raises:
    on a workflow-level failure it returns the state with rule-based
    fallback recommendations and the error recorded in "errors".

    Args:
        desktop_tokens / mobile_tokens: Extracted design-token dicts.
        competitors: Competitor names for the analyst prompts.
        log_callback: Optional line-logger for progress output.

    Returns:
        The final Stage2State dict, including "final_recommendations".
    """

    global cost_tracker
    cost_tracker = CostTracker()  # Reset per run (module-global tracker).

    if log_callback:
        log_callback("")
        log_callback("=" * 60)
        log_callback("🧠 STAGE 2: MULTI-AGENT ANALYSIS")
        log_callback("=" * 60)
        log_callback("")
        log_callback("πŸ“¦ LLM CONFIGURATION:")

        # Config banner. Guarded by `if log_callback` — these calls would
        # crash if issued unconditionally with log_callback=None.
        config = load_agent_config()
        for agent_key in ["stage2_llm1", "stage2_llm2", "stage2_head"]:
            agent = config.get(agent_key, {})
            log_callback(f"β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
            log_callback(f"β”‚ {agent.get('name', agent_key)}")
            log_callback(f"β”‚ Model: {agent.get('model', 'Unknown')}")
            log_callback(f"β”‚ Provider: {agent.get('provider', 'novita')}")
            log_callback(f"β”‚ πŸ’° Cost: ${agent.get('cost_per_million_input', 0.5)}/M in, ${agent.get('cost_per_million_output', 0.5)}/M out")
            log_callback(f"β”‚ Task: {', '.join(agent.get('tasks', [])[:2])}")
            log_callback(f"β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")

        log_callback("")
        log_callback("πŸ”„ RUNNING PARALLEL ANALYSIS...")

    # Initial state shared by every node.
    initial_state = {
        "desktop_tokens": desktop_tokens,
        "mobile_tokens": mobile_tokens,
        "competitors": competitors,
        "llm1_analysis": None,
        "llm2_analysis": None,
        "rule_calculations": None,
        "final_recommendations": None,
        "analysis_log": [],
        "cost_tracking": {},
        "errors": [],
        "start_time": time.time(),
        "llm1_time": 0,
        "llm2_time": 0,
        "head_time": 0,
    }

    def _merge(result: dict) -> None:
        # Merge a node's partial update. "errors" must be merged additively:
        # each node returns its own copy of the error list, so a plain
        # dict.update would let the last node's list overwrite errors
        # reported by earlier nodes.
        node_errors = result.pop("errors", None)
        if node_errors:
            for err in node_errors:
                if err not in initial_state["errors"]:
                    initial_state["errors"].append(err)
        initial_state.update(result)

    try:
        # Run LLM1, LLM2, and the (synchronous) rule engine in parallel.
        results = await asyncio.gather(
            analyze_with_llm1(initial_state, log_callback),
            analyze_with_llm2(initial_state, log_callback),
            asyncio.to_thread(run_rule_engine, initial_state, log_callback),
            return_exceptions=True,
        )

        for result in results:
            if isinstance(result, Exception):
                initial_state["errors"].append(str(result))
            elif isinstance(result, dict):
                _merge(result)

        # Run the HEAD compiler on the merged state.
        _merge(await compile_with_head(initial_state, log_callback))

        return initial_state

    except Exception as e:
        if log_callback:
            log_callback(f"❌ Workflow error: {str(e)}")

        initial_state["errors"].append(str(e))
        initial_state["final_recommendations"] = build_fallback_recommendations(initial_state)
        return initial_state