riazmo committed on
Commit
873b292
·
verified ·
1 Parent(s): b143d69

Delete agents/stage2_graph.py

Browse files
Files changed (1) hide show
  1. agents/stage2_graph.py +0 -751
agents/stage2_graph.py DELETED
@@ -1,751 +0,0 @@
1
- """
2
- Stage 2 Multi-Agent Analysis Workflow (LangGraph)
3
-
4
- Architecture:
5
- ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
6
- │ LLM 1 │ │ LLM 2 │ │ Rule Engine │
7
- │ (Qwen) │ │ (Llama) │ │ (No LLM) │
8
- └──────┬──────┘ └──────┬──────┘ └──────┬──────┘
9
- │ │ │
10
- │ PARALLEL │ │
11
- └───────────────────┼───────────────────┘
12
-
13
-
14
- ┌─────────────────┐
15
- │ HEAD │
16
- │ (Compiler) │
17
- └─────────────────┘
18
- """
19
-
20
- import asyncio
21
- import json
22
- import os
23
- import time
24
- import yaml
25
- from dataclasses import dataclass, field
26
- from datetime import datetime
27
- from typing import Any, Callable, Optional
28
-
29
- from langgraph.graph import END, START, StateGraph
30
- from typing_extensions import TypedDict
31
-
32
- # =============================================================================
33
- # CONFIGURATION LOADING
34
- # =============================================================================
35
-
36
def load_agent_config() -> dict:
    """Load the agent configuration from ``../config/agents.yaml``.

    The path is resolved relative to the directory containing this module.

    Returns:
        The parsed YAML mapping, or an empty dict when the file does not
        exist, is empty, or does not hold a mapping at its top level.
    """
    config_path = os.path.join(os.path.dirname(__file__), "..", "config", "agents.yaml")
    if not os.path.exists(config_path):
        return {}

    # Explicit encoding so the result does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # safe_load yields None for an empty document; every caller in this file
    # immediately calls .get() on the result, so always hand back a dict.
    return loaded if isinstance(loaded, dict) else {}
43
-
44
-
45
- # =============================================================================
46
- # STATE DEFINITION
47
- # =============================================================================
48
-
49
class Stage2State(TypedDict):
    """Shared state dictionary passed between the Stage 2 analysis steps."""

    # --- Inputs supplied by the caller ---------------------------------
    desktop_tokens: dict
    mobile_tokens: dict
    competitors: list[str]

    # --- Results of the three analysis branches (None until produced) ---
    llm1_analysis: Optional[dict]
    llm2_analysis: Optional[dict]
    rule_calculations: Optional[dict]

    # --- Result of the HEAD compiler -----------------------------------
    final_recommendations: Optional[dict]

    # --- Bookkeeping ----------------------------------------------------
    analysis_log: list[str]
    cost_tracking: dict
    errors: list[str]

    # --- Timing values (filled with time.time() by the steps) -----------
    start_time: float
    llm1_time: float
    llm2_time: float
    head_time: float
75
-
76
-
77
- # =============================================================================
78
- # COST TRACKING
79
- # =============================================================================
80
-
81
@dataclass
class CostTracker:
    """Accumulate token counts and estimated spend across LLM calls."""

    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    calls: list = field(default_factory=list)

    def add_call(self, agent_name: str, model: str, input_tokens: int, output_tokens: int,
                 cost_per_m_input: float, cost_per_m_output: float, duration: float):
        """Add one LLM call to the running totals and the per-call history.

        Prices are expressed per million tokens; the cost of the call is the
        sum of its input and output token costs.
        """
        prompt_cost = (input_tokens / 1_000_000) * cost_per_m_input
        completion_cost = (output_tokens / 1_000_000) * cost_per_m_output
        call_cost = prompt_cost + completion_cost

        record = {
            "agent": agent_name,
            "model": model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": call_cost,
            "duration": duration,
        }

        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.total_cost += call_cost
        self.calls.append(record)

    def to_dict(self) -> dict:
        """Return the totals (cost rounded to 6 decimals) and the call list."""
        summary = {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
        }
        summary["total_cost"] = round(self.total_cost, 6)
        summary["calls"] = self.calls
        return summary


# Module-wide tracker instance shared by the functions in this file
cost_tracker = CostTracker()
121
-
122
-
123
- # =============================================================================
124
- # LLM CLIENT
125
- # =============================================================================
126
-
127
async def call_llm(
    agent_name: str,
    model: str,
    provider: str,
    prompt: str,
    max_tokens: int = 1500,
    temperature: float = 0.4,
    cost_per_m_input: float = 0.5,
    cost_per_m_output: float = 0.5,
    log_callback: Optional[Callable] = None,
) -> tuple[str, int, int]:
    """Send a single-turn chat prompt to a model via HuggingFace Inference Providers.

    Args:
        agent_name: Label used in log lines and in the cost tracker entry.
        model: Model identifier passed to the inference client.
        provider: Inference provider the client is routed through.
        prompt: Text sent as the single user message.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        cost_per_m_input: Price per million input tokens (for cost estimates).
        cost_per_m_output: Price per million output tokens (for cost estimates).
        log_callback: Optional callable receiving progress messages.

    Returns:
        ``(content, input_tokens, output_tokens)``. The token counts are
        rough estimates (whitespace word count * 1.3), not provider-reported.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
        Exception: Any error from importing or calling the client is logged
            through ``log_callback`` and re-raised unchanged.
    """
    start_time = time.time()

    if log_callback:
        log_callback(f" 🚀 {agent_name}: Calling {model} via {provider}...")

    try:
        from huggingface_hub import InferenceClient

        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not set")

        # Provider routing is configured on the client itself;
        # chat_completion() has no `provider` keyword.
        client = InferenceClient(provider=provider, token=hf_token)

        # chat_completion is a blocking call. Run it in a worker thread so
        # coroutines gathered together with this one can make progress
        # instead of waiting on the event loop.
        response = await asyncio.to_thread(
            client.chat_completion,
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # Guard against a missing message body so the estimate below cannot
        # fail on None.
        content = response.choices[0].message.content or ""

        # Rough token estimate: ~1.3 tokens per whitespace-separated word.
        input_tokens = int(len(prompt.split()) * 1.3)
        output_tokens = int(len(content.split()) * 1.3)

        duration = time.time() - start_time

        cost_tracker.add_call(
            agent_name=agent_name,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_per_m_input=cost_per_m_input,
            cost_per_m_output=cost_per_m_output,
            duration=duration,
        )

        if log_callback:
            est_cost = ((input_tokens / 1_000_000) * cost_per_m_input +
                        (output_tokens / 1_000_000) * cost_per_m_output)
            log_callback(f" ✅ {agent_name}: Complete ({duration:.1f}s, ~{input_tokens} in, ~{output_tokens} out)")
            log_callback(f" 💵 Est. cost: ${est_cost:.4f}")

        return content, input_tokens, output_tokens

    except Exception as e:
        duration = time.time() - start_time
        if log_callback:
            log_callback(f" ❌ {agent_name}: Error after {duration:.1f}s - {str(e)}")
        raise
196
-
197
-
198
- # =============================================================================
199
- # ANALYSIS NODES
200
- # =============================================================================
201
-
202
async def analyze_with_llm1(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """Run the first analyst LLM over the extracted tokens.

    Settings come from the ``stage2_llm1`` section of the agent config, with
    built-in defaults for anything missing.

    Returns:
        A partial state update. On success: ``llm1_analysis`` (the parsed
        response plus a ``_meta`` entry) and ``llm1_time``. On failure:
        ``llm1_analysis`` holding the error text, ``errors`` extended with an
        ``LLM1:`` entry, and ``llm1_time``. ``llm1_time`` is the value of
        ``time.time()`` when the step finished.
    """
    settings = load_agent_config().get("stage2_llm1", {})

    model = settings.get("model", "Qwen/Qwen2.5-72B-Instruct")
    provider = settings.get("provider", "novita")
    input_rate = settings.get("cost_per_million_input", 0.29)
    output_rate = settings.get("cost_per_million_output", 0.59)

    if log_callback:
        banner = (
            "",
            f"🤖 LLM 1: {model}",
            f" Provider: {provider}",
            f" 💰 Cost: ${input_rate}/M in, ${output_rate}/M out",
            f" 📝 Task: Typography, Colors, AA, Spacing analysis",
        )
        for message in banner:
            log_callback(message)

    # The prompt is assembled outside the try block, so a failure here
    # propagates to the caller instead of being recorded as an LLM error.
    tokens_summary = summarize_tokens(state["desktop_tokens"], state["mobile_tokens"])
    prompt = build_analyst_prompt(
        tokens_summary=tokens_summary,
        competitors=state["competitors"],
        persona=settings.get("persona", "Senior Design Systems Architect"),
    )

    try:
        response, in_tokens, out_tokens = await call_llm(
            agent_name="LLM 1 (Qwen)",
            model=model,
            provider=provider,
            prompt=prompt,
            max_tokens=settings.get("max_tokens", 1500),
            temperature=settings.get("temperature", 0.4),
            cost_per_m_input=input_rate,
            cost_per_m_output=output_rate,
            log_callback=log_callback,
        )

        analysis = parse_llm_response(response)
        analysis["_meta"] = {
            "model": model,
            "provider": provider,
            "input_tokens": in_tokens,
            "output_tokens": out_tokens,
        }
    except Exception as e:
        failure = str(e)
        return {
            "llm1_analysis": {"error": failure},
            "errors": state.get("errors", []) + [f"LLM1: {failure}"],
            "llm1_time": time.time(),
        }

    return {"llm1_analysis": analysis, "llm1_time": time.time()}
255
-
256
-
257
async def analyze_with_llm2(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """Run the second analyst LLM over the extracted tokens.

    Settings come from the ``stage2_llm2`` section of the agent config, with
    built-in defaults for anything missing.

    Returns:
        A partial state update. On success: ``llm2_analysis`` (the parsed
        response plus a ``_meta`` entry) and ``llm2_time``. On failure:
        ``llm2_analysis`` holding the error text, ``errors`` extended with an
        ``LLM2:`` entry, and ``llm2_time``. ``llm2_time`` is the value of
        ``time.time()`` when the step finished.
    """
    settings = load_agent_config().get("stage2_llm2", {})

    model = settings.get("model", "meta-llama/Llama-3.3-70B-Instruct")
    provider = settings.get("provider", "novita")
    input_rate = settings.get("cost_per_million_input", 0.59)
    output_rate = settings.get("cost_per_million_output", 0.79)

    if log_callback:
        banner = (
            "",
            f"🤖 LLM 2: {model}",
            f" Provider: {provider}",
            f" 💰 Cost: ${input_rate}/M in, ${output_rate}/M out",
            f" 📝 Task: Typography, Colors, AA, Spacing analysis",
        )
        for message in banner:
            log_callback(message)

    # The prompt is assembled outside the try block, so a failure here
    # propagates to the caller instead of being recorded as an LLM error.
    tokens_summary = summarize_tokens(state["desktop_tokens"], state["mobile_tokens"])
    prompt = build_analyst_prompt(
        tokens_summary=tokens_summary,
        competitors=state["competitors"],
        persona=settings.get("persona", "Senior Design Systems Architect"),
    )

    try:
        response, in_tokens, out_tokens = await call_llm(
            agent_name="LLM 2 (Llama)",
            model=model,
            provider=provider,
            prompt=prompt,
            max_tokens=settings.get("max_tokens", 1500),
            temperature=settings.get("temperature", 0.4),
            cost_per_m_input=input_rate,
            cost_per_m_output=output_rate,
            log_callback=log_callback,
        )

        analysis = parse_llm_response(response)
        analysis["_meta"] = {
            "model": model,
            "provider": provider,
            "input_tokens": in_tokens,
            "output_tokens": out_tokens,
        }
    except Exception as e:
        failure = str(e)
        return {
            "llm2_analysis": {"error": failure},
            "errors": state.get("errors", []) + [f"LLM2: {failure}"],
            "llm2_time": time.time(),
        }

    return {"llm2_analysis": analysis, "llm2_time": time.time()}
310
-
311
-
312
def run_rule_engine(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """Compute deterministic design-token options without calling an LLM.

    Produces type scales for ratios 1.2 / 1.25 / 1.333 based on the detected
    desktop base font size, spacing scales for 4px and 8px bases, and a color
    ramp for each of the first eight desktop colors.

    Args:
        state: Workflow state; only ``desktop_tokens`` is read.
        log_callback: Optional callable receiving progress messages.

    Returns:
        ``{"rule_calculations": {...}}`` with keys ``base_font_size``,
        ``type_scales``, ``spacing_options`` and ``color_ramps``.
    """
    if log_callback:
        log_callback("")
        log_callback("⚙️ Rule Engine: Running calculations...")
        log_callback(" 💰 Cost: FREE (no LLM)")

    start = time.time()

    # Type scale options around the detected base size
    base_size = detect_base_font_size(state["desktop_tokens"])
    type_scales = {
        "1.2": generate_type_scale(base_size, 1.2),
        "1.25": generate_type_scale(base_size, 1.25),
        "1.333": generate_type_scale(base_size, 1.333),
    }

    # Spacing grid options
    spacing_options = {
        "4px": generate_spacing_scale(4),
        "8px": generate_spacing_scale(8),
    }

    # Imported here rather than at module level, as in the original code.
    from core.color_utils import generate_color_ramp

    color_ramps = {}
    colors = state["desktop_tokens"].get("colors", {})
    # Only the first eight colors get a ramp.
    for name, color in list(colors.items())[:8]:
        hex_val = color.get("value") if isinstance(color, dict) else str(color)
        try:
            color_ramps[name] = generate_color_ramp(hex_val)
        except Exception as exc:
            # Ramp generation stays best-effort: one bad color must not
            # abort the rule engine. Report the skip instead of hiding it,
            # and let KeyboardInterrupt/SystemExit through (the previous
            # bare ``except`` swallowed those too).
            if log_callback:
                log_callback(f" ⚠️ Rule Engine: Skipped color ramp for {name}: {exc}")

    duration = time.time() - start

    if log_callback:
        log_callback(f" ✅ Rule Engine: Complete ({duration:.2f}s)")
        log_callback(f" Generated: {len(type_scales)} type scales, {len(spacing_options)} spacing grids, {len(color_ramps)} color ramps")

    return {
        "rule_calculations": {
            "base_font_size": base_size,
            "type_scales": type_scales,
            "spacing_options": spacing_options,
            "color_ramps": color_ramps,
        }
    }
362
-
363
-
364
async def compile_with_head(state: Stage2State, log_callback: Optional[Callable] = None) -> dict:
    """Synthesize both analyst outputs and the rule calculations with the HEAD LLM.

    Settings come from the ``stage2_head`` section of the agent config, with
    built-in defaults for anything missing.

    Args:
        state: Workflow state; ``llm1_analysis``, ``llm2_analysis`` and
            ``rule_calculations`` are read and may each be missing or None.
        log_callback: Optional callable receiving progress messages.

    Returns:
        A partial state update containing ``final_recommendations``,
        ``cost_tracking`` and ``head_time`` (``time.time()`` at completion).
        If prompt building, the LLM call or parsing fails, the recommendations
        come from ``build_fallback_recommendations`` and ``errors`` is
        extended with a ``HEAD:`` entry.
    """
    config = load_agent_config()
    head_config = config.get("stage2_head", {})

    model = head_config.get("model", "meta-llama/Llama-3.3-70B-Instruct")
    provider = head_config.get("provider", "novita")

    if log_callback:
        log_callback("")
        log_callback("=" * 50)
        log_callback("🧠 HEAD COMPILER: Synthesizing results...")
        log_callback(f" Model: {model}")
        log_callback(f" Provider: {provider}")
        log_callback(f" 💰 Cost: ${head_config.get('cost_per_million_input', 0.59)}/M in, ${head_config.get('cost_per_million_output', 0.79)}/M out")

    try:
        # The state is initialised with these keys set to None, and
        # dict.get(key, {}) returns that None rather than the default.
        # Normalise with `or {}` so a branch that produced no result cannot
        # crash prompt building. The prompt is built inside the try so any
        # remaining failure still takes the fallback path below.
        prompt = build_head_prompt(
            llm1_analysis=state.get("llm1_analysis") or {},
            llm2_analysis=state.get("llm2_analysis") or {},
            rule_calculations=state.get("rule_calculations") or {},
        )

        response, in_tokens, out_tokens = await call_llm(
            agent_name="HEAD",
            model=model,
            provider=provider,
            prompt=prompt,
            max_tokens=head_config.get("max_tokens", 2000),
            temperature=head_config.get("temperature", 0.3),
            cost_per_m_input=head_config.get("cost_per_million_input", 0.59),
            cost_per_m_output=head_config.get("cost_per_million_output", 0.79),
            log_callback=log_callback,
        )

        # Parse response
        recommendations = parse_llm_response(response)
        recommendations["_meta"] = {
            "model": model,
            "provider": provider,
            "input_tokens": in_tokens,
            "output_tokens": out_tokens,
        }

        # Add cost summary
        recommendations["cost_summary"] = cost_tracker.to_dict()

        if log_callback:
            log_callback("")
            log_callback("=" * 50)
            log_callback(f"💰 TOTAL ESTIMATED COST: ${cost_tracker.total_cost:.4f}")
            log_callback(f" (Free tier: $0.10/mo | Pro: $2/mo)")
            log_callback("=" * 50)

        return {
            "final_recommendations": recommendations,
            "cost_tracking": cost_tracker.to_dict(),
            "head_time": time.time(),
        }

    except Exception as e:
        if log_callback:
            log_callback(f" ❌ HEAD Error: {str(e)}")

        # Fall back to the static recommendations. Cost tracking is still
        # reported so calls already recorded by the tracker are not lost.
        return {
            "final_recommendations": build_fallback_recommendations(state),
            "cost_tracking": cost_tracker.to_dict(),
            "errors": (state.get("errors") or []) + [f"HEAD: {str(e)}"],
            "head_time": time.time(),
        }
436
-
437
-
438
- # =============================================================================
439
- # HELPER FUNCTIONS
440
- # =============================================================================
441
-
442
def summarize_tokens(desktop: dict, mobile: dict) -> str:
    """Render a short markdown-style digest of the tokens for an LLM prompt.

    Lists up to five desktop colors and up to five desktop typography styles
    (non-dict typography entries are counted but not listed), followed by the
    counts of mobile typography styles and desktop spacing values.
    """
    palette = desktop.get("colors", {})
    color_lines = [
        f"- {name}: {entry.get('value') if isinstance(entry, dict) else str(entry)}"
        for name, entry in list(palette.items())[:5]
    ]

    desktop_styles = desktop.get("typography", {})
    style_lines = [
        f"- {name}: {style.get('font_size', '?')} / {style.get('font_weight', '?')}"
        for name, style in list(desktop_styles.items())[:5]
        if isinstance(style, dict)
    ]

    mobile_styles = mobile.get("typography", {})
    spacing_values = desktop.get("spacing", {})

    sections = [f"### Colors ({len(palette)} detected)"]
    sections += color_lines
    sections.append(f"\n### Typography Desktop ({len(desktop_styles)} styles)")
    sections += style_lines
    sections.append(f"\n### Typography Mobile ({len(mobile_styles)} styles)")
    sections.append(f"\n### Spacing ({len(spacing_values)} values)")

    return "\n".join(sections)
469
-
470
-
471
def build_analyst_prompt(tokens_summary: str, competitors: list[str], persona: str) -> str:
    """Assemble the prompt sent to each analyst LLM.

    The prompt states the persona, embeds the token summary and a
    comma-separated competitor list, enumerates the five analysis areas and
    ends with the JSON shape the model is asked to answer in.
    """
    competitor_list = ", ".join(competitors)

    prompt = f"""You are a {persona}.

## YOUR TASK
Analyze these design tokens extracted from a website and compare against industry best practices.

## EXTRACTED TOKENS
{tokens_summary}

## COMPETITOR DESIGN SYSTEMS TO RESEARCH
{competitor_list}

## ANALYZE THE FOLLOWING:

### 1. Typography
Is the type scale consistent? Does it follow a mathematical ratio?
What is the detected base size?
Compare to competitors: what ratios do they use?
Score (1-10) and specific recommendations

### 2. Colors
Is the color palette cohesive?
Are semantic colors properly defined (primary, secondary, etc.)?
Score (1-10) and specific recommendations

### 3. Accessibility (AA Compliance)
What contrast issues might exist?
Score (1-10)

### 4. Spacing
Is spacing consistent? Does it follow a grid (4px, 8px)?
Score (1-10) and specific recommendations

### 5. Overall Assessment
Top 3 priorities for improvement

## RESPOND IN JSON FORMAT ONLY:
```json
{{
"typography": {{"analysis": "...", "detected_ratio": 1.2, "score": 7, "recommendations": ["..."]}},
"colors": {{"analysis": "...", "score": 6, "recommendations": ["..."]}},
"accessibility": {{"issues": ["..."], "score": 5}},
"spacing": {{"analysis": "...", "detected_base": 8, "score": 7, "recommendations": ["..."]}},
"top_3_priorities": ["...", "...", "..."],
"confidence": 85
}}
```"""
    return prompt
519
-
520
-
521
def build_head_prompt(llm1_analysis: dict, llm2_analysis: dict, rule_calculations: dict) -> str:
    """Assemble the prompt for the HEAD compiler LLM.

    Each analyst result is serialized as indented JSON (non-serializable
    values are stringified) and cut to its first 2000 characters. The base
    font size comes from ``rule_calculations`` and defaults to 16.
    """
    first_findings = json.dumps(llm1_analysis, indent=2, default=str)[:2000]
    second_findings = json.dumps(llm2_analysis, indent=2, default=str)[:2000]
    base_font_size = rule_calculations.get('base_font_size', 16)

    prompt = f"""You are a Principal Design Systems Architect compiling analyses from two expert analysts.

## ANALYST 1 FINDINGS:
{first_findings}

## ANALYST 2 FINDINGS:
{second_findings}

## RULE-BASED CALCULATIONS:
- Base font size: {base_font_size}px
- Type scale options: 1.2, 1.25, 1.333
- Spacing options: 4px grid, 8px grid

## YOUR TASK:
1. Compare both analyst perspectives
2. Identify agreements and disagreements
3. Synthesize final recommendations

## RESPOND IN JSON FORMAT ONLY:
```json
{{
"agreements": [{{"topic": "...", "finding": "..."}}],
"disagreements": [{{"topic": "...", "resolution": "..."}}],
"final_recommendations": {{
"type_scale": "1.25",
"type_scale_rationale": "...",
"spacing_base": "8px",
"spacing_rationale": "...",
"color_improvements": ["..."],
"accessibility_fixes": ["..."]
}},
"overall_confidence": 85,
"summary": "..."
}}
```"""
    return prompt
558
-
559
-
560
def parse_llm_response(response: str) -> dict:
    """Extract a JSON object from an LLM response.

    Candidates are tried in order:

    1. the contents of the first fenced block, preferring a ```` ```json ````
       fence over a plain one; an unterminated fence (e.g. output cut off by
       the token limit) runs to the end of the text,
    2. the whole response,
    3. the span from the first ``{`` to the last ``}`` (JSON wrapped in prose).

    Args:
        response: Raw text returned by the model.

    Returns:
        The first candidate that parses to a JSON object. If none does (or
        the input is not a string), ``{"raw_response": <first 500 chars>,
        "parse_error": True}``.
    """
    text = response if isinstance(response, str) else ""

    candidates = []
    for fence in ("```json", "```"):
        opening = text.find(fence)
        if opening == -1:
            continue
        body_start = opening + len(fence)
        closing = text.find("```", body_start)
        # find() returns -1 when the closing fence is missing; slicing with
        # -1 would silently drop the last character, so take the remainder.
        candidates.append(text[body_start:] if closing == -1 else text[body_start:closing])
        break

    candidates.append(text)

    first_brace = text.find("{")
    last_brace = text.rfind("}")
    if first_brace != -1 and last_brace > first_brace:
        candidates.append(text[first_brace:last_brace + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate.strip())
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        # Callers add keys to the result, so only a JSON object is acceptable.
        if isinstance(parsed, dict):
            return parsed

    return {"raw_response": text[:500], "parse_error": True}
578
-
579
-
580
def detect_base_font_size(tokens: dict) -> int:
    """Detect the base (body) font size in pixels from typography tokens.

    Each dict-valued typography entry contributes its ``font_size`` (treated
    as ``16px`` when absent). Only sizes between 14 and 18 inclusive are
    considered; the most frequent one wins, with ties going to the size seen
    first.

    Args:
        tokens: Token dict; only its ``typography`` mapping is read.

    Returns:
        The detected size truncated to an int, or 16 when no entry qualifies.
    """
    from collections import Counter

    # NOTE(review): rem values are converted assuming the common browser
    # default of a 16px root font size — confirm for sites that override it.
    root_px = 16

    typography = tokens.get("typography", {})

    candidates = []
    for style in typography.values():
        if not isinstance(style, dict):
            continue

        raw = str(style.get("font_size", "16px")).strip().lower()
        scale = 1
        if raw.endswith("rem"):
            number, scale = raw[:-3], root_px
        elif raw.endswith("em"):
            # em depends on the parent element's size, which is unknown here;
            # the bare number is kept, as before.
            number = raw[:-2]
        elif raw.endswith("px"):
            number = raw[:-2]
        else:
            number = raw

        try:
            size = float(number) * scale
        except ValueError:
            continue  # not a numeric size (e.g. "inherit", "clamp(...)")

        if 14 <= size <= 18:
            candidates.append(size)

    if not candidates:
        return 16

    # Counter keeps first-seen order among equal counts, which makes the
    # tie-break deterministic.
    most_common_size, _ = Counter(candidates).most_common(1)[0]
    return int(most_common_size)
598
-
599
-
600
def generate_type_scale(base: int, ratio: float) -> list[int]:
    """Build a 13-step type scale around ``base``, largest size first.

    Step ``k`` is ``base * ratio ** k`` for ``k`` running from 8 down to -4,
    each snapped to an even integer via ``round(size / 2) * 2``.
    """
    def snap_to_even(size):
        return int(round(size / 2) * 2)

    return [snap_to_even(base * (ratio ** exponent)) for exponent in range(8, -5, -1)]
609
-
610
-
611
def generate_spacing_scale(base: int) -> list[int]:
    """Return the first 17 multiples of ``base``: 0, base, ..., 16 * base."""
    steps = range(17)
    return [base * step for step in steps]
614
-
615
-
616
def build_fallback_recommendations(state: Stage2State) -> dict:
    """Return static recommendations for use when the HEAD compiler fails.

    The result has the same top-level keys the HEAD prompt asks the model
    for (a nested ``final_recommendations`` block, ``overall_confidence``
    and ``summary``) plus ``fallback: True``.

    Args:
        state: Workflow state. Kept for interface compatibility; the
            returned values are fixed and do not depend on it.

    Returns:
        A fresh dict of default recommendations.
    """
    # The earlier version looked up state["rule_calculations"] into a local
    # that was never used; the dead lookup is removed so this last-resort
    # path cannot itself fail on an unexpected state value.
    return {
        "final_recommendations": {
            "type_scale": "1.25",
            "type_scale_rationale": "Major Third (1.25) is industry standard",
            "spacing_base": "8px",
            "spacing_rationale": "8px grid provides good visual rhythm",
            "color_improvements": ["Generate full ramps (50-950)"],
            "accessibility_fixes": ["Review contrast ratios"],
        },
        "overall_confidence": 60,
        "summary": "Recommendations based on rule-based analysis (LLM unavailable)",
        "fallback": True,
    }
633
-
634
-
635
- # =============================================================================
636
- # WORKFLOW BUILDER
637
- # =============================================================================
638
-
639
def build_stage2_workflow():
    """Assemble and compile the Stage 2 LangGraph workflow.

    Three analysis nodes fan out from START and all feed the HEAD compiler
    node, which leads to END.
    """
    graph = StateGraph(Stage2State)

    # Analysis branches, in registration order
    analysis_nodes = {
        "llm1_analyst": analyze_with_llm1,
        "llm2_analyst": analyze_with_llm2,
        "rule_engine": run_rule_engine,
    }

    for node_name, handler in analysis_nodes.items():
        graph.add_node(node_name, handler)
    graph.add_node("head_compiler", compile_with_head)

    # Fan out from START to every analysis branch
    for node_name in analysis_nodes:
        graph.add_edge(START, node_name)

    # Fan in: every analysis branch feeds the HEAD compiler
    for node_name in analysis_nodes:
        graph.add_edge(node_name, "head_compiler")

    graph.add_edge("head_compiler", END)

    return graph.compile()
664
-
665
-
666
- # =============================================================================
667
- # MAIN RUNNER
668
- # =============================================================================
669
-
670
async def run_stage2_multi_agent(
    desktop_tokens: dict,
    mobile_tokens: dict,
    competitors: list[str],
    log_callback: Optional[Callable] = None,
) -> dict:
    """Run the Stage 2 multi-agent analysis end to end.

    Both analyst LLMs and the rule engine are started together with
    ``asyncio.gather``; their partial results are merged into one state dict,
    which is then handed to the HEAD compiler.

    Args:
        desktop_tokens: Design tokens extracted for the desktop viewport.
        mobile_tokens: Design tokens extracted for the mobile viewport.
        competitors: Names of competitor design systems for the prompt.
        log_callback: Optional callable receiving progress messages.

    Returns:
        The final state dict. Failures are collected under ``errors``; if the
        workflow itself raises, ``final_recommendations`` is filled with the
        fallback recommendations instead.
    """
    global cost_tracker
    cost_tracker = CostTracker()  # Reset so totals cover this run only

    if log_callback:
        log_callback("")
        log_callback("=" * 60)
        log_callback("🧠 STAGE 2: MULTI-AGENT ANALYSIS")
        log_callback("=" * 60)
        log_callback("")
        log_callback("📦 LLM CONFIGURATION:")

        config = load_agent_config()

        for agent_key in ["stage2_llm1", "stage2_llm2", "stage2_head"]:
            agent = config.get(agent_key, {})
            log_callback(f"┌─────────────────────────────────────────────────────┐")
            log_callback(f"│ {agent.get('name', agent_key)}")
            log_callback(f"│ Model: {agent.get('model', 'Unknown')}")
            log_callback(f"│ Provider: {agent.get('provider', 'novita')}")
            log_callback(f"│ 💰 Cost: ${agent.get('cost_per_million_input', 0.5)}/M in, ${agent.get('cost_per_million_output', 0.5)}/M out")
            log_callback(f"│ Task: {', '.join(agent.get('tasks', [])[:2])}")
            log_callback(f"└─────────────────────────────────────────────────────┘")

        log_callback("")
        log_callback("🔄 RUNNING PARALLEL ANALYSIS...")

    # Initial state
    initial_state = {
        "desktop_tokens": desktop_tokens,
        "mobile_tokens": mobile_tokens,
        "competitors": competitors,
        "llm1_analysis": None,
        "llm2_analysis": None,
        "rule_calculations": None,
        "final_recommendations": None,
        "analysis_log": [],
        "cost_tracking": {},
        "errors": [],
        "start_time": time.time(),
        "llm1_time": 0,
        "llm2_time": 0,
        "head_time": 0,
    }

    try:
        # Run LLM1, LLM2, and Rules together; exceptions are returned as
        # results rather than raised so one failing branch does not discard
        # the others.
        results = await asyncio.gather(
            analyze_with_llm1(initial_state, log_callback),
            analyze_with_llm2(initial_state, log_callback),
            asyncio.to_thread(run_rule_engine, initial_state, log_callback),
            return_exceptions=True,
        )

        # Merge results. "errors" is accumulated, not overwritten: each
        # branch reports its own error list, and a plain dict.update() would
        # let the last branch's list replace everything recorded before it.
        for result in results:
            if isinstance(result, dict):
                initial_state.update(
                    {key: value for key, value in result.items() if key != "errors"}
                )
                for message in result.get("errors") or []:
                    if message not in initial_state["errors"]:
                        initial_state["errors"].append(message)
            elif isinstance(result, Exception):
                initial_state["errors"].append(str(result))

        # Run HEAD compiler. Any error list it returns is built from the
        # merged state, so a plain update keeps every earlier entry.
        head_result = await compile_with_head(initial_state, log_callback)
        initial_state.update(head_result)

        return initial_state

    except Exception as e:
        if log_callback:
            log_callback(f"❌ Workflow error: {str(e)}")

        initial_state["errors"].append(str(e))
        initial_state["final_recommendations"] = build_fallback_recommendations(initial_state)
        return initial_state