RayMelius Claude Sonnet 4.6 committed on
Commit
adb6d19
·
1 Parent(s): 29d9da4

Add Gemini LLM support, fix back-view direction, scripted conversation fallbacks

Browse files

- Add GeminiClient using OpenAI-compatible AI Studio endpoint (free tier:
15 RPM / 1M tokens/day on gemini-2.0-flash, set GEMINI_API_KEY to use)
- Auto-detect provider order: Claude → Groq → Gemini → Ollama
- Fix agent back-view: move direction tracking from drawPerson() to animate()
using pixel delta — reliable for all path angles, no waypoint guessing
- Add scripted fallback dialogue so conversations animate in the UI even
when no LLM is configured

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

src/soci/actions/conversation.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import logging
 
6
  from dataclasses import dataclass, field
7
  from typing import Optional, TYPE_CHECKING
8
 
@@ -114,9 +115,17 @@ async def initiate_conversation(
114
  max_tokens=512,
115
  )
116
 
117
- # LLM unavailable β€” skip conversation entirely
118
  if not result:
119
- return None
 
 
 
 
 
 
 
 
120
 
121
  message = result.get("message", f"Hey, {target.name}.")
122
  topic = result.get("topic", "small talk")
@@ -187,10 +196,17 @@ async def continue_conversation(
187
  max_tokens=512,
188
  )
189
 
190
- # LLM unavailable (rate-limited / circuit breaker) β€” end conversation cleanly
191
  if not result:
192
- conversation.is_active = False
193
- return last_turn
 
 
 
 
 
 
 
194
 
195
  message = result.get("message", "Hmm, interesting.")
196
 
 
3
  from __future__ import annotations
4
 
5
  import logging
6
+ import random
7
  from dataclasses import dataclass, field
8
  from typing import Optional, TYPE_CHECKING
9
 
 
115
  max_tokens=512,
116
  )
117
 
118
+ # LLM unavailable β€” use scripted fallback so conversations still animate in the UI
119
  if not result:
120
+ starters = [
121
+ {"message": f"Hey {target.name}, how's it going?", "topic": "greeting", "inner_thought": "Making small talk."},
122
+ {"message": f"Oh, {target.name}! Didn't expect to run into you here.", "topic": "chance meeting", "inner_thought": "Good to see a familiar face."},
123
+ {"message": "What have you been up to lately?", "topic": "small talk", "inner_thought": "Curious about their day."},
124
+ {"message": "Lovely weather today, isn't it?", "topic": "weather", "inner_thought": "Breaking the ice."},
125
+ {"message": f"Hi {target.name}! Have you heard any news lately?", "topic": "news", "inner_thought": "Looking for something to talk about."},
126
+ {"message": "I was just thinking about grabbing something to eat. You?", "topic": "food", "inner_thought": "Maybe we can go together."},
127
+ ]
128
+ result = random.choice(starters)
129
 
130
  message = result.get("message", f"Hey, {target.name}.")
131
  topic = result.get("topic", "small talk")
 
196
  max_tokens=512,
197
  )
198
 
199
+ # LLM unavailable β€” scripted response keeps conversation alive in the UI
200
  if not result:
201
+ replies = [
202
+ {"message": "Ha, yeah, I was just thinking the same thing!", "inner_thought": "Go with the flow.", "sentiment_delta": 0.05, "trust_delta": 0.02},
203
+ {"message": "Not too bad, honestly. Just keeping busy.", "inner_thought": "Keep it light.", "sentiment_delta": 0.03, "trust_delta": 0.01},
204
+ {"message": "Interesting! Tell me more.", "inner_thought": "Show some curiosity.", "sentiment_delta": 0.04, "trust_delta": 0.02},
205
+ {"message": "Yeah, it's been that kind of day.", "inner_thought": "Relate to them.", "sentiment_delta": 0.02, "trust_delta": 0.01},
206
+ {"message": "I hear you. Things have been a bit hectic on my end too.", "inner_thought": "Empathize.", "sentiment_delta": 0.04, "trust_delta": 0.03},
207
+ {"message": "Good point. I hadn't thought of it that way.", "inner_thought": "Give them credit.", "sentiment_delta": 0.05, "trust_delta": 0.03},
208
+ ]
209
+ result = random.choice(replies)
210
 
211
  message = result.get("message", "Hmm, interesting.")
212
 
src/soci/engine/llm.py CHANGED
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
18
  PROVIDER_CLAUDE = "claude"
19
  PROVIDER_OLLAMA = "ollama"
20
  PROVIDER_GROQ = "groq"
 
21
 
22
  # Claude model IDs
23
  MODEL_SONNET = "claude-sonnet-4-5-20250929"
@@ -35,6 +36,10 @@ MODEL_GROQ_LLAMA_8B = "llama-3.1-8b-instant"
35
  MODEL_GROQ_LLAMA_70B = "llama-3.3-70b-versatile"
36
  MODEL_GROQ_MIXTRAL = "mixtral-8x7b-32768"
37
 
 
 
 
 
38
  # Approximate cost per 1M tokens (USD) β€” Ollama is free, Groq is very cheap
39
  COST_PER_1M = {
40
  MODEL_SONNET: {"input": 3.0, "output": 15.0},
@@ -583,6 +588,192 @@ class GroqClient:
583
  return mapping.get(model, model)
584
 
585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
  # ============================================================
587
  # Factory β€” create the right client based on config
588
  # ============================================================
@@ -605,11 +796,13 @@ def create_llm_client(
605
  provider = os.environ.get("LLM_PROVIDER", "").lower()
606
 
607
  if not provider:
608
- # Auto-detect: Claude β†’ Groq β†’ Ollama
609
  if os.environ.get("ANTHROPIC_API_KEY"):
610
  provider = PROVIDER_CLAUDE
611
  elif os.environ.get("GROQ_API_KEY"):
612
  provider = PROVIDER_GROQ
 
 
613
  else:
614
  provider = PROVIDER_OLLAMA
615
 
@@ -619,11 +812,14 @@ def create_llm_client(
619
  elif provider == PROVIDER_GROQ:
620
  default_model = model or os.environ.get("GROQ_MODEL", MODEL_GROQ_LLAMA_8B)
621
  return GroqClient(default_model=default_model)
 
 
 
622
  elif provider == PROVIDER_OLLAMA:
623
  default_model = model or os.environ.get("OLLAMA_MODEL", MODEL_LLAMA)
624
  return OllamaClient(base_url=ollama_url, default_model=default_model)
625
  else:
626
- raise ValueError(f"Unknown LLM provider: {provider}. Use 'claude', 'groq', or 'ollama'.")
627
 
628
 
629
  # --- Prompt Templates ---
 
18
  PROVIDER_CLAUDE = "claude"
19
  PROVIDER_OLLAMA = "ollama"
20
  PROVIDER_GROQ = "groq"
21
+ PROVIDER_GEMINI = "gemini"
22
 
23
  # Claude model IDs
24
  MODEL_SONNET = "claude-sonnet-4-5-20250929"
 
36
  MODEL_GROQ_LLAMA_70B = "llama-3.3-70b-versatile"
37
  MODEL_GROQ_MIXTRAL = "mixtral-8x7b-32768"
38
 
39
+ # Google Gemini model IDs (free tier via AI Studio)
40
+ MODEL_GEMINI_FLASH = "gemini-2.0-flash"
41
+ MODEL_GEMINI_PRO = "gemini-1.5-pro"
42
+
43
  # Approximate cost per 1M tokens (USD) β€” Ollama is free, Groq is very cheap
44
  COST_PER_1M = {
45
  MODEL_SONNET: {"input": 3.0, "output": 15.0},
 
588
  return mapping.get(model, model)
589
 
590
 
591
# ============================================================
# Google Gemini Client (free tier via OpenAI-compatible endpoint)
# ============================================================


class GeminiClient:
    """Google Gemini via the OpenAI-compatible AI Studio endpoint.

    Free tier (no credit card):
      - gemini-2.0-flash: 15 RPM, 1 M tokens/day — plenty for a simulation.
      - Get a free key at https://aistudio.google.com/apikey
    Uses the OpenAI-compatible endpoint so no extra SDK is needed.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        default_model: str = MODEL_GEMINI_FLASH,
        max_retries: int = 3,
        max_rpm: int = 14,  # stay under the 15 RPM free-tier limit
    ) -> None:
        """Create a Gemini client.

        Args:
            api_key: AI Studio key; falls back to the GEMINI_API_KEY env var.
            default_model: model used when a call does not name one.
            max_retries: attempts per request before giving up.
            max_rpm: client-side pacing limit (requests per minute).

        Raises:
            ValueError: if no API key is given or found in the environment.
        """
        self.api_key = api_key or os.environ.get("GEMINI_API_KEY", "")
        if not self.api_key:
            raise ValueError(
                "GEMINI_API_KEY not set. "
                "Get a free key at https://aistudio.google.com/apikey"
            )
        self.default_model = default_model
        self.max_retries = max_retries
        self.usage = LLMUsage()
        self.provider = PROVIDER_GEMINI
        self._http = httpx.AsyncClient(
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            timeout=60.0,
        )
        # Client-side pacing: enforce a minimum interval between requests.
        self._min_request_interval = 60.0 / max_rpm
        self._last_request_time: float = 0.0
        self._rate_lock = asyncio.Lock()
        # Circuit breaker: monotonic deadline; requests are skipped until then.
        self._rate_limited_until: float = 0.0

    def _is_quota_exhausted(self) -> bool:
        """True while the daily-quota circuit breaker is tripped."""
        return time.monotonic() < self._rate_limited_until

    async def _wait_for_rate_limit(self) -> None:
        """Sleep as needed (holding the lock, so callers serialize) to honor max_rpm."""
        async with self._rate_lock:
            now = time.monotonic()
            elapsed = now - self._last_request_time
            if elapsed < self._min_request_interval:
                await asyncio.sleep(self._min_request_interval - elapsed)
            self._last_request_time = time.monotonic()

    def _map_model(self, model: str) -> str:
        """Map Claude/Groq model names to Gemini equivalents."""
        mapping = {
            MODEL_SONNET: self.default_model,
            MODEL_HAIKU: self.default_model,
            MODEL_GROQ_LLAMA_8B: MODEL_GEMINI_FLASH,
        }
        return mapping.get(model, model)

    async def _request(
        self,
        payload: dict,
        *,
        status_label: str,
        error_label: str,
        transform: Callable[[str], Any],
    ) -> Any:
        """POST one chat-completions payload with pacing, retries and 429 handling.

        Shared engine for complete()/complete_json(), which previously
        duplicated this loop verbatim. ``transform`` is applied to the
        response text *inside* the retry loop, so a malformed response is
        retried exactly like any other error (matching the old behavior).

        Returns transform(text), or None when every attempt failed — callers
        map None to their empty result ("" or {}).
        """
        model = payload["model"]
        for attempt in range(self.max_retries):
            last_attempt = attempt == self.max_retries - 1
            try:
                await self._wait_for_rate_limit()
                resp = await self._http.post("chat/completions", json=payload)
                resp.raise_for_status()
                data = resp.json()
                usage = data.get("usage", {})
                self.usage.record(model, usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0))
                return transform(data["choices"][0]["message"]["content"])
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 429:
                    try:
                        wait = float(e.response.headers.get("retry-after", "5"))
                    except (ValueError, TypeError):
                        wait = 5.0
                    if wait > 30:
                        # A long retry-after means the daily quota is gone —
                        # trip the breaker so later calls fail fast.
                        self._rate_limited_until = time.monotonic() + wait
                        logger.warning(f"Gemini quota exhausted for {wait:.0f}s")
                        return None
                    logger.warning(f"Gemini rate limited, waiting {wait}s")
                    # Fix: previously slept even on the final attempt, delaying
                    # the inevitable empty return for no reason.
                    if not last_attempt:
                        await asyncio.sleep(wait)
                else:
                    logger.error(f"Gemini {status_label}: {e.response.status_code}")
                    if last_attempt:
                        return None
                    await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"Gemini {error_label}: {e}")
                if last_attempt:
                    return None
                await asyncio.sleep(1)
        return None

    async def complete(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> str:
        """Send a chat completion request to Gemini.

        Returns the assistant text, or "" when the quota breaker is active
        or every retry failed.
        """
        if self._is_quota_exhausted():
            logger.debug("Gemini quota circuit breaker active — skipping complete()")
            return ""

        model = self._map_model(model or self.default_model)
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_message},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        text = await self._request(
            payload,
            status_label="HTTP error",
            error_label="error",
            transform=lambda t: t,
        )
        return text if text is not None else ""

    async def complete_json(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> dict:
        """Send a JSON-mode request to Gemini.

        Returns the parsed JSON object, or {} when the quota breaker is
        active or every retry failed.
        """
        if self._is_quota_exhausted():
            logger.debug("Gemini quota circuit breaker active — skipping complete_json()")
            return {}

        model = self._map_model(model or self.default_model)
        json_instruction = (
            "\n\nRespond ONLY with valid JSON. No markdown, no explanation, no extra text. "
            "Just the JSON object."
        )
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_message + json_instruction},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "response_format": {"type": "json_object"},
        }
        result = await self._request(
            payload,
            status_label="JSON error",
            error_label="JSON error",
            transform=_parse_json_response,
        )
        return result if result is not None else {}
777
  # ============================================================
778
  # Factory β€” create the right client based on config
779
  # ============================================================
 
796
  provider = os.environ.get("LLM_PROVIDER", "").lower()
797
 
798
  if not provider:
799
+ # Auto-detect: Claude β†’ Groq β†’ Gemini β†’ Ollama
800
  if os.environ.get("ANTHROPIC_API_KEY"):
801
  provider = PROVIDER_CLAUDE
802
  elif os.environ.get("GROQ_API_KEY"):
803
  provider = PROVIDER_GROQ
804
+ elif os.environ.get("GEMINI_API_KEY"):
805
+ provider = PROVIDER_GEMINI
806
  else:
807
  provider = PROVIDER_OLLAMA
808
 
 
812
  elif provider == PROVIDER_GROQ:
813
  default_model = model or os.environ.get("GROQ_MODEL", MODEL_GROQ_LLAMA_8B)
814
  return GroqClient(default_model=default_model)
815
+ elif provider == PROVIDER_GEMINI:
816
+ default_model = model or os.environ.get("GEMINI_MODEL", MODEL_GEMINI_FLASH)
817
+ return GeminiClient(default_model=default_model)
818
  elif provider == PROVIDER_OLLAMA:
819
  default_model = model or os.environ.get("OLLAMA_MODEL", MODEL_LLAMA)
820
  return OllamaClient(base_url=ollama_url, default_model=default_model)
821
  else:
822
+ raise ValueError(f"Unknown LLM provider: {provider}. Use 'claude', 'groq', 'gemini', or 'ollama'.")
823
 
824
 
825
  # --- Prompt Templates ---
web/index.html CHANGED
@@ -748,8 +748,19 @@ function animate() {
748
  // Moving agents travel slower so the walk is visible; others snap faster
749
  const isMoving = agent && (agent.state === 'moving');
750
  const lerpRate = isMoving ? 0.022 : 0.07;
 
751
  p.x += (dest.x - p.x) * lerpRate;
752
  p.y += (dest.y - p.y) * lerpRate;
 
 
 
 
 
 
 
 
 
 
753
  }
754
  draw();
755
  requestAnimationFrame(animate);
@@ -1948,17 +1959,7 @@ function drawPerson(id, agent, globalIdx, W, H) {
1948
  const armSwing = walkAnim ? Math.sin(walkPhase) * 10 : 0;
1949
  const tY = -10 + bounce; // torso top Y β€” hoisted so profile view can use it
1950
 
1951
- // Facing direction β€” track dominant movement axis (H or V)
1952
- const destPt = (agentWaypoints[id] && agentWaypoints[id].length) ? agentWaypoints[id][0] : agentTargets[id];
1953
- if (destPt && walkAnim) {
1954
- const ddx = destPt.x - ax, ddy = destPt.y - ay;
1955
- if (Math.abs(ddx) > Math.abs(ddy) + 5) {
1956
- agentFacingRight[id] = ddx > 0;
1957
- agentMovingUp[id] = false;
1958
- } else if (Math.abs(ddy) > Math.abs(ddx) + 5) {
1959
- agentMovingUp[id] = ddy < 0;
1960
- }
1961
- }
1962
  const facingRight = agentFacingRight[id] !== false; // default right
1963
  const movingUp = agentMovingUp[id] === true;
1964
 
 
748
  // Moving agents travel slower so the walk is visible; others snap faster
749
  const isMoving = agent && (agent.state === 'moving');
750
  const lerpRate = isMoving ? 0.022 : 0.07;
751
+ const prevX = p.x, prevY = p.y;
752
  p.x += (dest.x - p.x) * lerpRate;
753
  p.y += (dest.y - p.y) * lerpRate;
754
+ // Track facing direction from actual pixel delta β€” reliable for all path types
755
+ const mdx = p.x - prevX, mdy = p.y - prevY;
756
+ if (Math.abs(mdx) > 0.1 || Math.abs(mdy) > 0.1) {
757
+ if (Math.abs(mdy) > Math.abs(mdx)) {
758
+ agentMovingUp[id] = mdy < 0; // moving up = back view
759
+ } else {
760
+ agentFacingRight[id] = mdx > 0; // moving horizontally = profile
761
+ agentMovingUp[id] = false;
762
+ }
763
+ }
764
  }
765
  draw();
766
  requestAnimationFrame(animate);
 
1959
  const armSwing = walkAnim ? Math.sin(walkPhase) * 10 : 0;
1960
  const tY = -10 + bounce; // torso top Y β€” hoisted so profile view can use it
1961
 
1962
+ // Facing direction β€” maintained by animate() from position delta
 
 
 
 
 
 
 
 
 
 
1963
  const facingRight = agentFacingRight[id] !== false; // default right
1964
  const movingUp = agentMovingUp[id] === true;
1965