Add Groq LLM provider and fix speed controls for real fast-forward
- Add GroqClient for fast parallel cloud inference (free tier 30 req/min)
- Auto-detect: Claude -> Groq -> Ollama based on API keys
- Speed controls now actually affect simulation speed:
- 5x: limits to 2 conversations/tick
- 10x: limits to 1 conversation + 1 reflection/tick
- 50x: pure routine mode, zero LLM calls, instant ticks
- Skip sleep delay entirely at high speeds
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .env.example +5 -1
- main.py +2 -2
- src/soci/api/server.py +24 -1
- src/soci/engine/llm.py +183 -6
- src/soci/engine/simulation.py +55 -34
.env.example
CHANGED
|
@@ -1,9 +1,13 @@
|
|
| 1 |
-
# LLM Provider: "claude" or "ollama" (auto-detects if not set)
|
| 2 |
# LLM_PROVIDER=ollama
|
| 3 |
|
| 4 |
# For Claude (paid API):
|
| 5 |
# ANTHROPIC_API_KEY=sk-ant-api03-your-key-here
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# For Ollama (free, local):
|
| 8 |
# Install: https://ollama.com
|
| 9 |
# Then: ollama pull llama3.1
|
|
|
|
| 1 |
+
# LLM Provider: "claude", "groq", or "ollama" (auto-detects if not set)
|
| 2 |
# LLM_PROVIDER=ollama
|
| 3 |
|
| 4 |
# For Claude (paid API):
|
| 5 |
# ANTHROPIC_API_KEY=sk-ant-api03-your-key-here
|
| 6 |
|
| 7 |
+
# For Groq (fast cloud, free tier 30 req/min):
|
| 8 |
+
# Sign up: https://console.groq.com
|
| 9 |
+
# GROQ_API_KEY=gsk_your-key-here
|
| 10 |
+
|
| 11 |
# For Ollama (free, local):
|
| 12 |
# Install: https://ollama.com
|
| 13 |
# Then: ollama pull llama3.1
|
main.py
CHANGED
|
@@ -231,8 +231,8 @@ def main():
|
|
| 231 |
parser.add_argument("--resume", action="store_true", help="Resume from last save")
|
| 232 |
parser.add_argument("--generate", action="store_true",
|
| 233 |
help="Generate procedural agents to fill up to --agents count")
|
| 234 |
-
parser.add_argument("--provider", type=str, default="", choices=["", "claude", "ollama"],
|
| 235 |
-
help="LLM provider: claude or ollama (default: auto-detect)")
|
| 236 |
parser.add_argument("--model", type=str, default="",
|
| 237 |
help="Model name (e.g. llama3.1:8b, mistral, qwen2.5)")
|
| 238 |
args = parser.parse_args()
|
|
|
|
| 231 |
parser.add_argument("--resume", action="store_true", help="Resume from last save")
|
| 232 |
parser.add_argument("--generate", action="store_true",
|
| 233 |
help="Generate procedural agents to fill up to --agents count")
|
| 234 |
+
parser.add_argument("--provider", type=str, default="", choices=["", "claude", "groq", "ollama"],
|
| 235 |
+
help="LLM provider: claude, groq, or ollama (default: auto-detect)")
|
| 236 |
parser.add_argument("--model", type=str, default="",
|
| 237 |
help="Model name (e.g. llama3.1:8b, mistral, qwen2.5)")
|
| 238 |
args = parser.parse_args()
|
src/soci/api/server.py
CHANGED
|
@@ -50,11 +50,34 @@ async def simulation_loop(sim: Simulation, db: Database, tick_delay: float = 2.0
|
|
| 50 |
if _sim_paused:
|
| 51 |
await asyncio.sleep(0.5)
|
| 52 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
await sim.tick()
|
|
|
|
| 54 |
# Auto-save every 24 ticks
|
| 55 |
if sim.clock.total_ticks % 24 == 0:
|
| 56 |
await save_simulation(sim, db, "autosave")
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
except asyncio.CancelledError:
|
| 59 |
logger.info("Simulation loop cancelled")
|
| 60 |
await save_simulation(sim, db, "autosave")
|
|
|
|
| 50 |
if _sim_paused:
|
| 51 |
await asyncio.sleep(0.5)
|
| 52 |
continue
|
| 53 |
+
|
| 54 |
+
# At high speeds, limit LLM calls to keep ticks fast
|
| 55 |
+
# _sim_speed < 0.2 means 5x+, so cap concurrent conversations
|
| 56 |
+
if _sim_speed <= 0.05:
|
| 57 |
+
# 50x: skip LLM entirely, pure routine mode
|
| 58 |
+
sim._skip_llm_this_tick = True
|
| 59 |
+
elif _sim_speed <= 0.15:
|
| 60 |
+
# 10x: max 1 conversation per tick
|
| 61 |
+
sim._max_convos_this_tick = 1
|
| 62 |
+
elif _sim_speed <= 0.35:
|
| 63 |
+
# 5x: max 2 conversations per tick
|
| 64 |
+
sim._max_convos_this_tick = 2
|
| 65 |
+
else:
|
| 66 |
+
sim._skip_llm_this_tick = False
|
| 67 |
+
sim._max_convos_this_tick = 0 # 0 = no limit
|
| 68 |
+
|
| 69 |
await sim.tick()
|
| 70 |
+
|
| 71 |
# Auto-save every 24 ticks
|
| 72 |
if sim.clock.total_ticks % 24 == 0:
|
| 73 |
await save_simulation(sim, db, "autosave")
|
| 74 |
+
|
| 75 |
+
# At high speeds, skip the delay entirely
|
| 76 |
+
delay = tick_delay * _sim_speed
|
| 77 |
+
if delay > 0.05:
|
| 78 |
+
await asyncio.sleep(delay)
|
| 79 |
+
else:
|
| 80 |
+
await asyncio.sleep(0) # Yield to event loop
|
| 81 |
except asyncio.CancelledError:
|
| 82 |
logger.info("Simulation loop cancelled")
|
| 83 |
await save_simulation(sim, db, "autosave")
|
src/soci/engine/llm.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""LLM client — supports Claude API and Ollama (local LLMs) with model routing and cost tracking."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
@@ -17,6 +17,7 @@ logger = logging.getLogger(__name__)
|
|
| 17 |
# --- Provider constants ---
|
| 18 |
PROVIDER_CLAUDE = "claude"
|
| 19 |
PROVIDER_OLLAMA = "ollama"
|
|
|
|
| 20 |
|
| 21 |
# Claude model IDs
|
| 22 |
MODEL_SONNET = "claude-sonnet-4-5-20250929"
|
|
@@ -29,10 +30,18 @@ MODEL_MISTRAL = "mistral"
|
|
| 29 |
MODEL_QWEN = "qwen2.5"
|
| 30 |
MODEL_GEMMA = "gemma2"
|
| 31 |
|
| 32 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
COST_PER_1M = {
|
| 34 |
MODEL_SONNET: {"input": 3.0, "output": 15.0},
|
| 35 |
MODEL_HAIKU: {"input": 0.80, "output": 4.0},
|
|
|
|
|
|
|
|
|
|
| 36 |
}
|
| 37 |
|
| 38 |
|
|
@@ -346,6 +355,168 @@ class OllamaClient:
|
|
| 346 |
return mapping.get(model, model)
|
| 347 |
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
# ============================================================
|
| 350 |
# Factory — create the right client based on config
|
| 351 |
# ============================================================
|
|
@@ -354,33 +525,39 @@ def create_llm_client(
|
|
| 354 |
provider: Optional[str] = None,
|
| 355 |
model: Optional[str] = None,
|
| 356 |
ollama_url: str = "http://localhost:11434",
|
| 357 |
-
) -> ClaudeClient | OllamaClient:
|
| 358 |
"""Create an LLM client based on environment or explicit config.
|
| 359 |
|
| 360 |
Provider detection order:
|
| 361 |
1. Explicit provider argument
|
| 362 |
2. LLM_PROVIDER env var
|
| 363 |
3. If ANTHROPIC_API_KEY is set → Claude
|
| 364 |
-
4.
|
|
|
|
| 365 |
"""
|
| 366 |
if provider is None:
|
| 367 |
provider = os.environ.get("LLM_PROVIDER", "").lower()
|
| 368 |
|
| 369 |
if not provider:
|
| 370 |
-
# Auto-detect:
|
| 371 |
if os.environ.get("ANTHROPIC_API_KEY"):
|
| 372 |
provider = PROVIDER_CLAUDE
|
|
|
|
|
|
|
| 373 |
else:
|
| 374 |
provider = PROVIDER_OLLAMA
|
| 375 |
|
| 376 |
if provider == PROVIDER_CLAUDE:
|
| 377 |
default_model = model or MODEL_HAIKU
|
| 378 |
return ClaudeClient(default_model=default_model)
|
|
|
|
|
|
|
|
|
|
| 379 |
elif provider == PROVIDER_OLLAMA:
|
| 380 |
default_model = model or MODEL_LLAMA
|
| 381 |
return OllamaClient(base_url=ollama_url, default_model=default_model)
|
| 382 |
else:
|
| 383 |
-
raise ValueError(f"Unknown LLM provider: {provider}. Use 'claude' or 'ollama'.")
|
| 384 |
|
| 385 |
|
| 386 |
# --- Prompt Templates ---
|
|
|
|
| 1 |
+
"""LLM client — supports Claude API, Groq, and Ollama (local LLMs) with model routing and cost tracking."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 17 |
# --- Provider constants ---
|
| 18 |
PROVIDER_CLAUDE = "claude"
|
| 19 |
PROVIDER_OLLAMA = "ollama"
|
| 20 |
+
PROVIDER_GROQ = "groq"
|
| 21 |
|
| 22 |
# Claude model IDs
|
| 23 |
MODEL_SONNET = "claude-sonnet-4-5-20250929"
|
|
|
|
| 30 |
MODEL_QWEN = "qwen2.5"
|
| 31 |
MODEL_GEMMA = "gemma2"
|
| 32 |
|
| 33 |
+
# Groq model IDs (fast cloud inference)
|
| 34 |
+
MODEL_GROQ_LLAMA_8B = "llama-3.1-8b-instant"
|
| 35 |
+
MODEL_GROQ_LLAMA_70B = "llama-3.3-70b-versatile"
|
| 36 |
+
MODEL_GROQ_MIXTRAL = "mixtral-8x7b-32768"
|
| 37 |
+
|
| 38 |
+
# Approximate cost per 1M tokens (USD) — Ollama is free, Groq is very cheap
|
| 39 |
COST_PER_1M = {
|
| 40 |
MODEL_SONNET: {"input": 3.0, "output": 15.0},
|
| 41 |
MODEL_HAIKU: {"input": 0.80, "output": 4.0},
|
| 42 |
+
MODEL_GROQ_LLAMA_8B: {"input": 0.05, "output": 0.08},
|
| 43 |
+
MODEL_GROQ_LLAMA_70B: {"input": 0.59, "output": 0.79},
|
| 44 |
+
MODEL_GROQ_MIXTRAL: {"input": 0.24, "output": 0.24},
|
| 45 |
}
|
| 46 |
|
| 47 |
|
|
|
|
| 355 |
return mapping.get(model, model)
|
| 356 |
|
| 357 |
|
| 358 |
+
# ============================================================
|
| 359 |
+
# Groq (Fast Cloud Inference) Client
|
| 360 |
+
# ============================================================
|
| 361 |
+
|
| 362 |
+
class GroqClient:
    """Wrapper around the Groq API for fast cloud inference.

    Groq provides extremely fast inference (~500 tok/s) with parallel request support.
    Free tier: 30 requests/min on llama-3.1-8b-instant.
    Sign up: https://console.groq.com
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        default_model: str = MODEL_GROQ_LLAMA_8B,
        max_retries: int = 3,
    ) -> None:
        """Create a Groq client.

        Args:
            api_key: Groq API key; falls back to the GROQ_API_KEY env var.
            default_model: Model used when a call does not specify one.
            max_retries: Attempts per request before giving up.

        Raises:
            ValueError: If no API key is available.
        """
        self.api_key = api_key or os.environ.get("GROQ_API_KEY", "")
        if not self.api_key:
            raise ValueError(
                "GROQ_API_KEY not set. Get a free key at https://console.groq.com"
            )
        self.default_model = default_model
        self.max_retries = max_retries
        self.usage = LLMUsage()
        self.provider = PROVIDER_GROQ
        # One shared async client; Groq exposes an OpenAI-compatible endpoint.
        self._http = httpx.AsyncClient(
            base_url="https://api.groq.com/openai/v1",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            timeout=60.0,
        )

    async def _post(self, payload: dict) -> dict:
        """POST one chat-completion payload and record token usage.

        Shared by complete() and complete_json(); raises
        httpx.HTTPStatusError on non-2xx responses.
        """
        response = await self._http.post("/chat/completions", json=payload)
        response.raise_for_status()
        data = response.json()

        usage = data.get("usage", {})
        self.usage.record(
            payload["model"],
            usage.get("prompt_tokens", 0),
            usage.get("completion_tokens", 0),
        )
        return data

    async def complete(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> str:
        """Send a chat completion request to Groq (async, parallel-safe).

        Returns the assistant message text, or "" if every attempt was
        rate-limited. Invalid credentials raise ValueError; other HTTP and
        transport errors are re-raised after the final attempt.
        """
        model = self._map_model(model or self.default_model)

        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_message},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        for attempt in range(self.max_retries):
            last_attempt = attempt == self.max_retries - 1
            try:
                data = await self._post(payload)
                return data["choices"][0]["message"]["content"]

            except httpx.HTTPStatusError as e:
                if e.response.status_code == 429:
                    if last_attempt:
                        break  # Give up immediately — no point sleeping first.
                    # Rate limited — exponential backoff, then retry.
                    wait = 2 ** attempt + 1
                    logger.warning(f"Groq rate limited, waiting {wait}s (attempt {attempt + 1})")
                    await asyncio.sleep(wait)
                elif e.response.status_code == 401:
                    # Chain the HTTP error so the original context survives.
                    raise ValueError("Invalid GROQ_API_KEY") from e
                else:
                    logger.error(f"Groq API error: {e.response.status_code} {e.response.text[:200]}")
                    if last_attempt:
                        raise
                    await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"Groq error: {e}")
                if last_attempt:
                    raise
                await asyncio.sleep(1)
        return ""

    async def complete_json(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> dict:
        """Send a JSON-mode request to Groq.

        Best-effort: returns the parsed JSON object, or {} when all
        retries fail (never raises on API errors).
        """
        model = self._map_model(model or self.default_model)

        # Belt and braces: JSON response_format plus an explicit instruction,
        # since small models occasionally wrap output in markdown anyway.
        json_instruction = (
            "\n\nRespond ONLY with valid JSON. No markdown, no explanation, no extra text. "
            "Just the JSON object."
        )

        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_message + json_instruction},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "response_format": {"type": "json_object"},
        }

        for attempt in range(self.max_retries):
            last_attempt = attempt == self.max_retries - 1
            try:
                data = await self._post(payload)
                text = data["choices"][0]["message"]["content"]
                return _parse_json_response(text)

            except httpx.HTTPStatusError as e:
                if e.response.status_code == 429:
                    if last_attempt:
                        break  # Give up immediately — no point sleeping first.
                    wait = 2 ** attempt + 1
                    logger.warning(f"Groq rate limited, waiting {wait}s")
                    await asyncio.sleep(wait)
                else:
                    logger.error(f"Groq JSON error: {e.response.status_code}")
                    if last_attempt:
                        return {}
                    await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"Groq JSON error: {e}")
                if last_attempt:
                    return {}
                await asyncio.sleep(1)
        return {}

    def _map_model(self, model: str) -> str:
        """Map Claude/Ollama model names to Groq equivalents.

        Unknown names pass through unchanged so explicit Groq IDs work.
        """
        mapping = {
            MODEL_SONNET: MODEL_GROQ_LLAMA_70B,  # Use 70B for "smart" model
            MODEL_HAIKU: self.default_model,  # Use default (8B) for routine
            MODEL_LLAMA: MODEL_GROQ_LLAMA_8B,
        }
        return mapping.get(model, model)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
# ============================================================
|
| 521 |
# Factory — create the right client based on config
|
| 522 |
# ============================================================
|
|
|
|
| 525 |
provider: Optional[str] = None,
|
| 526 |
model: Optional[str] = None,
|
| 527 |
ollama_url: str = "http://localhost:11434",
|
| 528 |
+
) -> ClaudeClient | OllamaClient | GroqClient:
|
| 529 |
"""Create an LLM client based on environment or explicit config.
|
| 530 |
|
| 531 |
Provider detection order:
|
| 532 |
1. Explicit provider argument
|
| 533 |
2. LLM_PROVIDER env var
|
| 534 |
3. If ANTHROPIC_API_KEY is set → Claude
|
| 535 |
+
4. If GROQ_API_KEY is set → Groq (fast cloud, parallel)
|
| 536 |
+
5. Default → Ollama (free, local)
|
| 537 |
"""
|
| 538 |
if provider is None:
|
| 539 |
provider = os.environ.get("LLM_PROVIDER", "").lower()
|
| 540 |
|
| 541 |
if not provider:
|
| 542 |
+
# Auto-detect: Claude → Groq → Ollama
|
| 543 |
if os.environ.get("ANTHROPIC_API_KEY"):
|
| 544 |
provider = PROVIDER_CLAUDE
|
| 545 |
+
elif os.environ.get("GROQ_API_KEY"):
|
| 546 |
+
provider = PROVIDER_GROQ
|
| 547 |
else:
|
| 548 |
provider = PROVIDER_OLLAMA
|
| 549 |
|
| 550 |
if provider == PROVIDER_CLAUDE:
|
| 551 |
default_model = model or MODEL_HAIKU
|
| 552 |
return ClaudeClient(default_model=default_model)
|
| 553 |
+
elif provider == PROVIDER_GROQ:
|
| 554 |
+
default_model = model or MODEL_GROQ_LLAMA_8B
|
| 555 |
+
return GroqClient(default_model=default_model)
|
| 556 |
elif provider == PROVIDER_OLLAMA:
|
| 557 |
default_model = model or MODEL_LLAMA
|
| 558 |
return OllamaClient(base_url=ollama_url, default_model=default_model)
|
| 559 |
else:
|
| 560 |
+
raise ValueError(f"Unknown LLM provider: {provider}. Use 'claude', 'groq', or 'ollama'.")
|
| 561 |
|
| 562 |
|
| 563 |
# --- Prompt Templates ---
|
src/soci/engine/simulation.py
CHANGED
|
@@ -59,6 +59,9 @@ class Simulation:
|
|
| 59 |
# Daily routines per agent (rebuilt from persona each day)
|
| 60 |
self.routines: dict[str, DailyRoutine] = {}
|
| 61 |
self._last_routine_day: int = -1
|
|
|
|
|
|
|
|
|
|
| 62 |
# Callback for real-time output
|
| 63 |
self.on_event: Optional[Callable[[str], None]] = None
|
| 64 |
|
|
@@ -211,55 +214,73 @@ class Simulation:
|
|
| 211 |
routine_actions.append((agent, action))
|
| 212 |
continue
|
| 213 |
|
| 214 |
-
# No routine slot — fallback to LLM (rare)
|
| 215 |
-
|
| 216 |
-
|
|
|
|
| 217 |
|
| 218 |
# Execute routine-driven actions (no LLM needed)
|
| 219 |
for agent, action in routine_actions:
|
| 220 |
await self._execute_action(agent, action)
|
| 221 |
|
| 222 |
# Run LLM action decisions concurrently (only for agents without routine match)
|
| 223 |
-
if action_coros:
|
| 224 |
action_results = await batch_llm_calls(action_coros, self._max_concurrent)
|
| 225 |
for agent, result in zip(action_agents, action_results):
|
| 226 |
if result and isinstance(result, AgentAction):
|
| 227 |
await self._execute_action(agent, result)
|
| 228 |
|
| 229 |
-
# 6. Handle active conversations
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
self._finish_conversation(conv)
|
| 234 |
-
|
| 235 |
-
continue
|
| 236 |
-
# Determine who speaks next
|
| 237 |
-
last_speaker = conv.turns[-1].speaker_id if conv.turns else None
|
| 238 |
-
next_speaker_id = [p for p in conv.participants if p != last_speaker]
|
| 239 |
-
if next_speaker_id:
|
| 240 |
-
responder = self.agents.get(next_speaker_id[0])
|
| 241 |
-
other = self.agents.get(last_speaker) if last_speaker else None
|
| 242 |
-
if responder and other:
|
| 243 |
-
conv_coros.append(
|
| 244 |
-
continue_conversation(conv, responder, other, self.llm, self.clock)
|
| 245 |
-
)
|
| 246 |
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
await self._handle_social_interactions(ordered_agents)
|
| 252 |
|
| 253 |
# 8. Reflections for agents with enough accumulated importance
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
# 9. Romance — develop attractions and relationships
|
| 265 |
self._tick_romance()
|
|
|
|
| 59 |
# Daily routines per agent (rebuilt from persona each day)
|
| 60 |
self.routines: dict[str, DailyRoutine] = {}
|
| 61 |
self._last_routine_day: int = -1
|
| 62 |
+
# Speed-aware flags (set by server loop for fast-forward)
|
| 63 |
+
self._skip_llm_this_tick: bool = False
|
| 64 |
+
self._max_convos_this_tick: int = 0 # 0 = no limit
|
| 65 |
# Callback for real-time output
|
| 66 |
self.on_event: Optional[Callable[[str], None]] = None
|
| 67 |
|
|
|
|
| 214 |
routine_actions.append((agent, action))
|
| 215 |
continue
|
| 216 |
|
| 217 |
+
# No routine slot — fallback to LLM (rare), skip in fast-forward
|
| 218 |
+
if not self._skip_llm_this_tick:
|
| 219 |
+
action_coros.append(self._decide_action(agent))
|
| 220 |
+
action_agents.append(agent)
|
| 221 |
|
| 222 |
# Execute routine-driven actions (no LLM needed)
|
| 223 |
for agent, action in routine_actions:
|
| 224 |
await self._execute_action(agent, action)
|
| 225 |
|
| 226 |
# Run LLM action decisions concurrently (only for agents without routine match)
|
| 227 |
+
if action_coros and not self._skip_llm_this_tick:
|
| 228 |
action_results = await batch_llm_calls(action_coros, self._max_concurrent)
|
| 229 |
for agent, result in zip(action_agents, action_results):
|
| 230 |
if result and isinstance(result, AgentAction):
|
| 231 |
await self._execute_action(agent, result)
|
| 232 |
|
| 233 |
+
# 6. Handle active conversations (skip in 50x mode)
|
| 234 |
+
if not self._skip_llm_this_tick:
|
| 235 |
+
conv_coros = []
|
| 236 |
+
for conv_id, conv in list(self.active_conversations.items()):
|
| 237 |
+
if conv.is_finished:
|
| 238 |
+
self._finish_conversation(conv)
|
| 239 |
+
del self.active_conversations[conv_id]
|
| 240 |
+
continue
|
| 241 |
+
# Determine who speaks next
|
| 242 |
+
last_speaker = conv.turns[-1].speaker_id if conv.turns else None
|
| 243 |
+
next_speaker_id = [p for p in conv.participants if p != last_speaker]
|
| 244 |
+
if next_speaker_id:
|
| 245 |
+
responder = self.agents.get(next_speaker_id[0])
|
| 246 |
+
other = self.agents.get(last_speaker) if last_speaker else None
|
| 247 |
+
if responder and other:
|
| 248 |
+
conv_coros.append(
|
| 249 |
+
continue_conversation(conv, responder, other, self.llm, self.clock)
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# Limit conversations at high speed
|
| 253 |
+
if self._max_convos_this_tick > 0 and len(conv_coros) > self._max_convos_this_tick:
|
| 254 |
+
conv_coros = conv_coros[:self._max_convos_this_tick]
|
| 255 |
+
|
| 256 |
+
if conv_coros:
|
| 257 |
+
await batch_llm_calls(conv_coros, self._max_concurrent)
|
| 258 |
+
else:
|
| 259 |
+
# 50x mode: force-finish all active conversations
|
| 260 |
+
for conv_id, conv in list(self.active_conversations.items()):
|
| 261 |
self._finish_conversation(conv)
|
| 262 |
+
self.active_conversations.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
+
# 7. Social: maybe start new conversations (respect speed limits)
|
| 265 |
+
if not self._skip_llm_this_tick:
|
| 266 |
+
if self._max_convos_this_tick == 0 or len(self.active_conversations) < self._max_convos_this_tick:
|
| 267 |
+
await self._handle_social_interactions(ordered_agents)
|
|
|
|
| 268 |
|
| 269 |
# 8. Reflections for agents with enough accumulated importance
|
| 270 |
+
if not self._skip_llm_this_tick:
|
| 271 |
+
reflect_coros = []
|
| 272 |
+
reflect_agents = []
|
| 273 |
+
for agent in ordered_agents:
|
| 274 |
+
if agent.memory.should_reflect() and not agent.is_player:
|
| 275 |
+
reflect_coros.append(self._generate_reflection(agent))
|
| 276 |
+
reflect_agents.append(agent)
|
| 277 |
+
|
| 278 |
+
# At 10x, limit reflections to 1 per tick
|
| 279 |
+
if self._max_convos_this_tick > 0 and len(reflect_coros) > 1:
|
| 280 |
+
reflect_coros = reflect_coros[:1]
|
| 281 |
+
|
| 282 |
+
if reflect_coros:
|
| 283 |
+
await batch_llm_calls(reflect_coros, self._max_concurrent)
|
| 284 |
|
| 285 |
# 9. Romance — develop attractions and relationships
|
| 286 |
self._tick_romance()
|