RayMelius Claude Opus 4.6 committed on
Commit
af7b74e
·
1 Parent(s): 4784d87

Add Groq LLM provider and fix speed controls for real fast-forward

Browse files

- Add GroqClient for fast parallel cloud inference (free tier 30 req/min)
- Auto-detect: Claude -> Groq -> Ollama based on API keys
- Speed controls now actually affect simulation speed:
- 5x: limits to 2 conversations/tick
- 10x: limits to 1 conversation + 1 reflection/tick
- 50x: pure routine mode, zero LLM calls, instant ticks
- Skip sleep delay entirely at high speeds

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

.env.example CHANGED
@@ -1,9 +1,13 @@
1
- # LLM Provider: "claude" or "ollama" (auto-detects if not set)
2
  # LLM_PROVIDER=ollama
3
 
4
  # For Claude (paid API):
5
  # ANTHROPIC_API_KEY=sk-ant-api03-your-key-here
6
 
 
 
 
 
7
  # For Ollama (free, local):
8
  # Install: https://ollama.com
9
  # Then: ollama pull llama3.1
 
1
+ # LLM Provider: "claude", "groq", or "ollama" (auto-detects if not set)
2
  # LLM_PROVIDER=ollama
3
 
4
  # For Claude (paid API):
5
  # ANTHROPIC_API_KEY=sk-ant-api03-your-key-here
6
 
7
+ # For Groq (fast cloud, free tier 30 req/min):
8
+ # Sign up: https://console.groq.com
9
+ # GROQ_API_KEY=gsk_your-key-here
10
+
11
  # For Ollama (free, local):
12
  # Install: https://ollama.com
13
  # Then: ollama pull llama3.1
main.py CHANGED
@@ -231,8 +231,8 @@ def main():
231
  parser.add_argument("--resume", action="store_true", help="Resume from last save")
232
  parser.add_argument("--generate", action="store_true",
233
  help="Generate procedural agents to fill up to --agents count")
234
- parser.add_argument("--provider", type=str, default="", choices=["", "claude", "ollama"],
235
- help="LLM provider: claude or ollama (default: auto-detect)")
236
  parser.add_argument("--model", type=str, default="",
237
  help="Model name (e.g. llama3.1:8b, mistral, qwen2.5)")
238
  args = parser.parse_args()
 
231
  parser.add_argument("--resume", action="store_true", help="Resume from last save")
232
  parser.add_argument("--generate", action="store_true",
233
  help="Generate procedural agents to fill up to --agents count")
234
+ parser.add_argument("--provider", type=str, default="", choices=["", "claude", "groq", "ollama"],
235
+ help="LLM provider: claude, groq, or ollama (default: auto-detect)")
236
  parser.add_argument("--model", type=str, default="",
237
  help="Model name (e.g. llama3.1:8b, mistral, qwen2.5)")
238
  args = parser.parse_args()
src/soci/api/server.py CHANGED
@@ -50,11 +50,34 @@ async def simulation_loop(sim: Simulation, db: Database, tick_delay: float = 2.0
50
  if _sim_paused:
51
  await asyncio.sleep(0.5)
52
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  await sim.tick()
 
54
  # Auto-save every 24 ticks
55
  if sim.clock.total_ticks % 24 == 0:
56
  await save_simulation(sim, db, "autosave")
57
- await asyncio.sleep(tick_delay * _sim_speed)
 
 
 
 
 
 
58
  except asyncio.CancelledError:
59
  logger.info("Simulation loop cancelled")
60
  await save_simulation(sim, db, "autosave")
 
50
  if _sim_paused:
51
  await asyncio.sleep(0.5)
52
  continue
53
+
54
+ # At high speeds, limit LLM calls to keep ticks fast
55
+ # _sim_speed <= 0.35 means 5x or faster, so cap concurrent conversations
56
+ if _sim_speed <= 0.05:
57
+ # 50x: skip LLM entirely, pure routine mode
58
+ sim._skip_llm_this_tick = True
59
+ elif _sim_speed <= 0.15:
60
+ # 10x: max 1 conversation per tick
61
+ sim._max_convos_this_tick = 1
62
+ elif _sim_speed <= 0.35:
63
+ # 5x: max 2 conversations per tick
64
+ sim._max_convos_this_tick = 2
65
+ else:
66
+ sim._skip_llm_this_tick = False
67
+ sim._max_convos_this_tick = 0 # 0 = no limit
68
+
69
  await sim.tick()
70
+
71
  # Auto-save every 24 ticks
72
  if sim.clock.total_ticks % 24 == 0:
73
  await save_simulation(sim, db, "autosave")
74
+
75
+ # At high speeds, skip the delay entirely
76
+ delay = tick_delay * _sim_speed
77
+ if delay > 0.05:
78
+ await asyncio.sleep(delay)
79
+ else:
80
+ await asyncio.sleep(0) # Yield to event loop
81
  except asyncio.CancelledError:
82
  logger.info("Simulation loop cancelled")
83
  await save_simulation(sim, db, "autosave")
src/soci/engine/llm.py CHANGED
@@ -1,4 +1,4 @@
1
- """LLM client — supports Claude API and Ollama (local LLMs) with model routing and cost tracking."""
2
 
3
  from __future__ import annotations
4
 
@@ -17,6 +17,7 @@ logger = logging.getLogger(__name__)
17
  # --- Provider constants ---
18
  PROVIDER_CLAUDE = "claude"
19
  PROVIDER_OLLAMA = "ollama"
 
20
 
21
  # Claude model IDs
22
  MODEL_SONNET = "claude-sonnet-4-5-20250929"
@@ -29,10 +30,18 @@ MODEL_MISTRAL = "mistral"
29
  MODEL_QWEN = "qwen2.5"
30
  MODEL_GEMMA = "gemma2"
31
 
32
- # Approximate cost per 1M tokens (USD) Ollama is free
 
 
 
 
 
33
  COST_PER_1M = {
34
  MODEL_SONNET: {"input": 3.0, "output": 15.0},
35
  MODEL_HAIKU: {"input": 0.80, "output": 4.0},
 
 
 
36
  }
37
 
38
 
@@ -346,6 +355,168 @@ class OllamaClient:
346
  return mapping.get(model, model)
347
 
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  # ============================================================
350
  # Factory — create the right client based on config
351
  # ============================================================
@@ -354,33 +525,39 @@ def create_llm_client(
354
  provider: Optional[str] = None,
355
  model: Optional[str] = None,
356
  ollama_url: str = "http://localhost:11434",
357
- ) -> ClaudeClient | OllamaClient:
358
  """Create an LLM client based on environment or explicit config.
359
 
360
  Provider detection order:
361
  1. Explicit provider argument
362
  2. LLM_PROVIDER env var
363
  3. If ANTHROPIC_API_KEY is set → Claude
364
- 4. DefaultOllama (free, local)
 
365
  """
366
  if provider is None:
367
  provider = os.environ.get("LLM_PROVIDER", "").lower()
368
 
369
  if not provider:
370
- # Auto-detect: use Claude if key is set, otherwise Ollama
371
  if os.environ.get("ANTHROPIC_API_KEY"):
372
  provider = PROVIDER_CLAUDE
 
 
373
  else:
374
  provider = PROVIDER_OLLAMA
375
 
376
  if provider == PROVIDER_CLAUDE:
377
  default_model = model or MODEL_HAIKU
378
  return ClaudeClient(default_model=default_model)
 
 
 
379
  elif provider == PROVIDER_OLLAMA:
380
  default_model = model or MODEL_LLAMA
381
  return OllamaClient(base_url=ollama_url, default_model=default_model)
382
  else:
383
- raise ValueError(f"Unknown LLM provider: {provider}. Use 'claude' or 'ollama'.")
384
 
385
 
386
  # --- Prompt Templates ---
 
1
+ """LLM client — supports Claude API, Groq, and Ollama (local LLMs) with model routing and cost tracking."""
2
 
3
  from __future__ import annotations
4
 
 
17
  # --- Provider constants ---
18
  PROVIDER_CLAUDE = "claude"
19
  PROVIDER_OLLAMA = "ollama"
20
+ PROVIDER_GROQ = "groq"
21
 
22
  # Claude model IDs
23
  MODEL_SONNET = "claude-sonnet-4-5-20250929"
 
30
  MODEL_QWEN = "qwen2.5"
31
  MODEL_GEMMA = "gemma2"
32
 
33
+ # Groq model IDs (fast cloud inference)
34
+ MODEL_GROQ_LLAMA_8B = "llama-3.1-8b-instant"
35
+ MODEL_GROQ_LLAMA_70B = "llama-3.3-70b-versatile"
36
+ MODEL_GROQ_MIXTRAL = "mixtral-8x7b-32768"
37
+
38
+ # Approximate cost per 1M tokens (USD) — Ollama is free, Groq is very cheap
39
  COST_PER_1M = {
40
  MODEL_SONNET: {"input": 3.0, "output": 15.0},
41
  MODEL_HAIKU: {"input": 0.80, "output": 4.0},
42
+ MODEL_GROQ_LLAMA_8B: {"input": 0.05, "output": 0.08},
43
+ MODEL_GROQ_LLAMA_70B: {"input": 0.59, "output": 0.79},
44
+ MODEL_GROQ_MIXTRAL: {"input": 0.24, "output": 0.24},
45
  }
46
 
47
 
 
355
  return mapping.get(model, model)
356
 
357
 
358
+ # ============================================================
359
+ # Groq (Fast Cloud Inference) Client
360
+ # ============================================================
361
+
362
class GroqClient:
    """Wrapper around the Groq API for fast cloud inference.

    Groq provides extremely fast inference (~500 tok/s) with parallel request support.
    Free tier: 30 requests/min on llama-3.1-8b-instant.
    Sign up: https://console.groq.com
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        default_model: str = MODEL_GROQ_LLAMA_8B,
        max_retries: int = 3,
    ) -> None:
        """Create a Groq client.

        Args:
            api_key: Groq API key; falls back to the GROQ_API_KEY env var.
            default_model: Model used when a call does not specify one.
            max_retries: Attempts per request before giving up.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.environ.get("GROQ_API_KEY", "")
        if not self.api_key:
            raise ValueError(
                "GROQ_API_KEY not set. Get a free key at https://console.groq.com"
            )
        self.default_model = default_model
        self.max_retries = max_retries
        self.usage = LLMUsage()
        self.provider = PROVIDER_GROQ
        # Groq exposes an OpenAI-compatible REST API; one shared async client
        # gives us connection pooling across parallel simulation calls.
        self._http = httpx.AsyncClient(
            base_url="https://api.groq.com/openai/v1",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            timeout=60.0,
        )

    async def aclose(self) -> None:
        """Release the underlying HTTP connection pool."""
        await self._http.aclose()

    async def _post_chat(self, payload: dict) -> dict:
        """POST one chat-completion request and record token usage.

        Raises httpx.HTTPStatusError on non-2xx responses; callers implement
        their own retry policy around this helper.
        """
        response = await self._http.post("/chat/completions", json=payload)
        response.raise_for_status()
        data = response.json()
        usage = data.get("usage", {})
        self.usage.record(
            payload["model"],
            usage.get("prompt_tokens", 0),
            usage.get("completion_tokens", 0),
        )
        return data

    async def complete(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> str:
        """Send a chat completion request to Groq (async, parallel-safe).

        Returns the assistant message text, or "" if every retry was rate
        limited. Raises ValueError on a bad API key, and re-raises other
        errors once retries are exhausted.
        """
        model = self._map_model(model or self.default_model)

        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_message},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        for attempt in range(self.max_retries):
            try:
                data = await self._post_chat(payload)
                return data["choices"][0]["message"]["content"]
            except httpx.HTTPStatusError as e:
                status = e.response.status_code
                if status == 401:
                    raise ValueError("Invalid GROQ_API_KEY") from e
                if status == 429:
                    # Rate limited — exponential backoff, but don't sleep
                    # after the final attempt (we're about to give up anyway).
                    wait = 2 ** attempt + 1
                    logger.warning(f"Groq rate limited, waiting {wait}s (attempt {attempt + 1})")
                else:
                    logger.error(f"Groq API error: {status} {e.response.text[:200]}")
                    if attempt == self.max_retries - 1:
                        raise
                    wait = 1
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(wait)
            except Exception as e:
                logger.error(f"Groq error: {e}")
                if attempt == self.max_retries - 1:
                    raise
                await asyncio.sleep(1)
        return ""

    async def complete_json(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> dict:
        """Send a JSON-mode request to Groq.

        Returns the parsed JSON object, or {} if all retries fail — callers
        treat an empty dict as a best-effort miss rather than an error.
        """
        model = self._map_model(model or self.default_model)

        json_instruction = (
            "\n\nRespond ONLY with valid JSON. No markdown, no explanation, no extra text. "
            "Just the JSON object."
        )

        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_message + json_instruction},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            # Ask Groq's OpenAI-compatible API to constrain output to JSON.
            "response_format": {"type": "json_object"},
        }

        for attempt in range(self.max_retries):
            try:
                data = await self._post_chat(payload)
                text = data["choices"][0]["message"]["content"]
                return _parse_json_response(text)
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 429:
                    wait = 2 ** attempt + 1
                    logger.warning(f"Groq rate limited, waiting {wait}s")
                    if attempt < self.max_retries - 1:
                        await asyncio.sleep(wait)
                else:
                    logger.error(f"Groq JSON error: {e.response.status_code}")
                    if attempt == self.max_retries - 1:
                        return {}
                    await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"Groq JSON error: {e}")
                if attempt == self.max_retries - 1:
                    return {}
                await asyncio.sleep(1)
        return {}

    def _map_model(self, model: str) -> str:
        """Map Claude/Ollama model names to Groq equivalents.

        Unknown names pass through unchanged so explicit Groq model IDs work.
        """
        mapping = {
            MODEL_SONNET: MODEL_GROQ_LLAMA_70B,  # Use 70B for "smart" model
            MODEL_HAIKU: self.default_model,     # Use default (8B) for routine
            MODEL_LLAMA: MODEL_GROQ_LLAMA_8B,
        }
        return mapping.get(model, model)
518
+
519
+
520
  # ============================================================
521
  # Factory — create the right client based on config
522
  # ============================================================
 
525
  provider: Optional[str] = None,
526
  model: Optional[str] = None,
527
  ollama_url: str = "http://localhost:11434",
528
+ ) -> ClaudeClient | OllamaClient | GroqClient:
529
  """Create an LLM client based on environment or explicit config.
530
 
531
  Provider detection order:
532
  1. Explicit provider argument
533
  2. LLM_PROVIDER env var
534
  3. If ANTHROPIC_API_KEY is set → Claude
535
+ 4. If GROQ_API_KEY is set → Groq (fast cloud, parallel)
536
+ 5. Default → Ollama (free, local)
537
  """
538
  if provider is None:
539
  provider = os.environ.get("LLM_PROVIDER", "").lower()
540
 
541
  if not provider:
542
+ # Auto-detect: Claude → Groq → Ollama
543
  if os.environ.get("ANTHROPIC_API_KEY"):
544
  provider = PROVIDER_CLAUDE
545
+ elif os.environ.get("GROQ_API_KEY"):
546
+ provider = PROVIDER_GROQ
547
  else:
548
  provider = PROVIDER_OLLAMA
549
 
550
  if provider == PROVIDER_CLAUDE:
551
  default_model = model or MODEL_HAIKU
552
  return ClaudeClient(default_model=default_model)
553
+ elif provider == PROVIDER_GROQ:
554
+ default_model = model or MODEL_GROQ_LLAMA_8B
555
+ return GroqClient(default_model=default_model)
556
  elif provider == PROVIDER_OLLAMA:
557
  default_model = model or MODEL_LLAMA
558
  return OllamaClient(base_url=ollama_url, default_model=default_model)
559
  else:
560
+ raise ValueError(f"Unknown LLM provider: {provider}. Use 'claude', 'groq', or 'ollama'.")
561
 
562
 
563
  # --- Prompt Templates ---
src/soci/engine/simulation.py CHANGED
@@ -59,6 +59,9 @@ class Simulation:
59
  # Daily routines per agent (rebuilt from persona each day)
60
  self.routines: dict[str, DailyRoutine] = {}
61
  self._last_routine_day: int = -1
 
 
 
62
  # Callback for real-time output
63
  self.on_event: Optional[Callable[[str], None]] = None
64
 
@@ -211,55 +214,73 @@ class Simulation:
211
  routine_actions.append((agent, action))
212
  continue
213
 
214
- # No routine slot — fallback to LLM (rare)
215
- action_coros.append(self._decide_action(agent))
216
- action_agents.append(agent)
 
217
 
218
  # Execute routine-driven actions (no LLM needed)
219
  for agent, action in routine_actions:
220
  await self._execute_action(agent, action)
221
 
222
  # Run LLM action decisions concurrently (only for agents without routine match)
223
- if action_coros:
224
  action_results = await batch_llm_calls(action_coros, self._max_concurrent)
225
  for agent, result in zip(action_agents, action_results):
226
  if result and isinstance(result, AgentAction):
227
  await self._execute_action(agent, result)
228
 
229
- # 6. Handle active conversations
230
- conv_coros = []
231
- for conv_id, conv in list(self.active_conversations.items()):
232
- if conv.is_finished:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  self._finish_conversation(conv)
234
- del self.active_conversations[conv_id]
235
- continue
236
- # Determine who speaks next
237
- last_speaker = conv.turns[-1].speaker_id if conv.turns else None
238
- next_speaker_id = [p for p in conv.participants if p != last_speaker]
239
- if next_speaker_id:
240
- responder = self.agents.get(next_speaker_id[0])
241
- other = self.agents.get(last_speaker) if last_speaker else None
242
- if responder and other:
243
- conv_coros.append(
244
- continue_conversation(conv, responder, other, self.llm, self.clock)
245
- )
246
 
247
- if conv_coros:
248
- await batch_llm_calls(conv_coros, self._max_concurrent)
249
-
250
- # 7. Social: maybe start new conversations
251
- await self._handle_social_interactions(ordered_agents)
252
 
253
  # 8. Reflections for agents with enough accumulated importance
254
- reflect_coros = []
255
- reflect_agents = []
256
- for agent in ordered_agents:
257
- if agent.memory.should_reflect() and not agent.is_player:
258
- reflect_coros.append(self._generate_reflection(agent))
259
- reflect_agents.append(agent)
260
-
261
- if reflect_coros:
262
- await batch_llm_calls(reflect_coros, self._max_concurrent)
 
 
 
 
 
263
 
264
  # 9. Romance — develop attractions and relationships
265
  self._tick_romance()
 
59
  # Daily routines per agent (rebuilt from persona each day)
60
  self.routines: dict[str, DailyRoutine] = {}
61
  self._last_routine_day: int = -1
62
+ # Speed-aware flags (set by server loop for fast-forward)
63
+ self._skip_llm_this_tick: bool = False
64
+ self._max_convos_this_tick: int = 0 # 0 = no limit
65
  # Callback for real-time output
66
  self.on_event: Optional[Callable[[str], None]] = None
67
 
 
214
  routine_actions.append((agent, action))
215
  continue
216
 
217
+ # No routine slot — fallback to LLM (rare), skip in fast-forward
218
+ if not self._skip_llm_this_tick:
219
+ action_coros.append(self._decide_action(agent))
220
+ action_agents.append(agent)
221
 
222
  # Execute routine-driven actions (no LLM needed)
223
  for agent, action in routine_actions:
224
  await self._execute_action(agent, action)
225
 
226
  # Run LLM action decisions concurrently (only for agents without routine match)
227
+ if action_coros and not self._skip_llm_this_tick:
228
  action_results = await batch_llm_calls(action_coros, self._max_concurrent)
229
  for agent, result in zip(action_agents, action_results):
230
  if result and isinstance(result, AgentAction):
231
  await self._execute_action(agent, result)
232
 
233
+ # 6. Handle active conversations (skip in 50x mode)
234
+ if not self._skip_llm_this_tick:
235
+ conv_coros = []
236
+ for conv_id, conv in list(self.active_conversations.items()):
237
+ if conv.is_finished:
238
+ self._finish_conversation(conv)
239
+ del self.active_conversations[conv_id]
240
+ continue
241
+ # Determine who speaks next
242
+ last_speaker = conv.turns[-1].speaker_id if conv.turns else None
243
+ next_speaker_id = [p for p in conv.participants if p != last_speaker]
244
+ if next_speaker_id:
245
+ responder = self.agents.get(next_speaker_id[0])
246
+ other = self.agents.get(last_speaker) if last_speaker else None
247
+ if responder and other:
248
+ conv_coros.append(
249
+ continue_conversation(conv, responder, other, self.llm, self.clock)
250
+ )
251
+
252
+ # Limit conversations at high speed
253
+ if self._max_convos_this_tick > 0 and len(conv_coros) > self._max_convos_this_tick:
254
+ conv_coros = conv_coros[:self._max_convos_this_tick]
255
+
256
+ if conv_coros:
257
+ await batch_llm_calls(conv_coros, self._max_concurrent)
258
+ else:
259
+ # 50x mode: force-finish all active conversations
260
+ for conv_id, conv in list(self.active_conversations.items()):
261
  self._finish_conversation(conv)
262
+ self.active_conversations.clear()
 
 
 
 
 
 
 
 
 
 
 
263
 
264
+ # 7. Social: maybe start new conversations (respect speed limits)
265
+ if not self._skip_llm_this_tick:
266
+ if self._max_convos_this_tick == 0 or len(self.active_conversations) < self._max_convos_this_tick:
267
+ await self._handle_social_interactions(ordered_agents)
 
268
 
269
  # 8. Reflections for agents with enough accumulated importance
270
+ if not self._skip_llm_this_tick:
271
+ reflect_coros = []
272
+ reflect_agents = []
273
+ for agent in ordered_agents:
274
+ if agent.memory.should_reflect() and not agent.is_player:
275
+ reflect_coros.append(self._generate_reflection(agent))
276
+ reflect_agents.append(agent)
277
+
278
+ # At 10x, limit reflections to 1 per tick
279
+ if self._max_convos_this_tick > 0 and len(reflect_coros) > 1:
280
+ reflect_coros = reflect_coros[:1]
281
+
282
+ if reflect_coros:
283
+ await batch_llm_calls(reflect_coros, self._max_concurrent)
284
 
285
  # 9. Romance — develop attractions and relationships
286
  self._tick_romance()