"""Neural Network LLM client — replaces cloud LLM with local ONNX model.

Downloads soci-agent-nn from HuggingFace Hub on first use, then runs
inference via ONNX Runtime. Zero API calls, zero cost, ~1ms per batch.

Drop-in replacement for GeminiClient/GroqClient — implements the same
complete() / complete_json() interface expected by Simulation.
"""
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | import logging |
| | import math |
| | import os |
| | import random |
| | import re |
| | from dataclasses import dataclass, field |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import numpy as np |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | try: |
| | import onnxruntime as ort |
| | except ImportError: |
| | ort = None |
| |
|
| | |
| |
|
| | ACTION_TYPES = ["move", "work", "eat", "sleep", "talk", "exercise", "shop", "relax", "wander"] |
| | ACTION_TO_IDX = {a: i for i, a in enumerate(ACTION_TYPES)} |
| |
|
| | LOCATIONS = [ |
| | |
| | "house_elena", "house_marcus", "house_helen", "house_diana", "house_kai", |
| | "house_priya", "house_james", "house_rosa", "house_yuki", "house_frank", |
| | "apartment_block_1", "apartment_block_2", "apartment_block_3", |
| | "apt_northeast", "apt_northwest", "apt_southeast", "apt_southwest", |
| | |
| | "cafe", "grocery", "bar", "restaurant", "bakery", "cinema", "diner", "pharmacy", |
| | |
| | "office", "office_tower", "factory", "school", "hospital", |
| | |
| | "park", "gym", "library", "church", "town_square", "sports_field", |
| | "street_north", "street_south", "street_east", "street_west", |
| | ] |
| | LOC_TO_IDX = {loc: i for i, loc in enumerate(LOCATIONS)} |
| |
|
| | NEED_NAMES = ["hunger", "energy", "social", "purpose", "comfort", "fun"] |
| |
|
| | ACTION_DURATIONS = {"move": 1, "work": 4, "eat": 2, "sleep": 8, "talk": 2, "exercise": 3, "shop": 2, "relax": 2, "wander": 1} |
| |
|
| | FEATURE_DIM = 47 |
| |
|
| | |
| | |
| | _GREETINGS = [ |
| | "Hey, how's it going?", "Morning!", "Hi there.", "What's up?", |
| | "Hey! Haven't seen you in a while.", "Oh hey, I was just thinking about you.", |
| | ] |
| | _SMALLTALK = [ |
| | "Nice weather today, isn't it?", "Been busy lately?", |
| | "Did you hear about the event in the square?", "I've been meaning to ask you something.", |
| | "This place is getting crowded.", "I could really use a coffee.", |
| | ] |
| | _REPLIES = [ |
| | "Yeah, totally.", "I know what you mean.", "Ha, right?", |
| | "That's interesting.", "Tell me more.", "I hadn't thought about it that way.", |
| | "Hmm, I'm not so sure about that.", "Oh really?", "Same here.", |
| | ] |
| |
|
| |
|
| | |
| |
|
| | def _time_period(hour: int) -> int: |
| | if hour < 6: return 0 |
| | if hour < 9: return 1 |
| | if hour < 12: return 2 |
| | if hour < 14: return 3 |
| | if hour < 18: return 4 |
| | if hour < 22: return 5 |
| | return 6 |
| |
|
| |
|
def encode_features(
    personality: dict[str, float],
    age: float,
    hour: int,
    minute: int,
    day: int,
    needs: dict[str, float],
    mood: float,
    current_loc: str,
    home_loc: str = "",
    work_loc: str = "",
    num_people_here: int = 0,
) -> np.ndarray:
    """Encode agent state into the 47-dim feature vector the ONNX model expects.

    Layout, in order: 5 personality traits (divided by 10), age / 100,
    sin/cos of hour and minute, day-of-week fraction and weekend flag,
    the six needs, clamped mood, most-urgent-need index / 5, critical-need
    flag, zone / 3, at-home flag, at-work flag, crowding (capped at 1.0),
    a 6-slot location one-hot, a 7-slot time-period one-hot, and nine
    slots that are always zero.

    Returns:
        A float32 array of shape (1, 47).
    """
    trait_names = (
        "openness", "conscientiousness", "extraversion", "agreeableness", "neuroticism",
    )
    vec: list[float] = [personality.get(trait, 5) / 10.0 for trait in trait_names]

    vec.append(age / 100.0)

    # Cyclical clock encoding so that 23:59 and 00:00 end up close together.
    hour_angle = 2 * math.pi * hour / 24
    minute_angle = 2 * math.pi * minute / 60
    vec.append(math.sin(hour_angle))
    vec.append(math.cos(hour_angle))
    vec.append(math.sin(minute_angle))
    vec.append(math.cos(minute_angle))

    # Day numbering starts at 1; indices 5 and 6 count as the weekend.
    weekday = (day - 1) % 7
    vec.append(weekday / 7.0)
    vec.append(1.0 if weekday >= 5 else 0.0)

    # Missing needs default to the neutral midpoint.
    need_levels = [needs.get(name, 0.5) for name in NEED_NAMES]
    vec.extend(need_levels)

    vec.append(max(-1.0, min(1.0, mood)))

    # The lowest need is the most urgent one; below 0.15 it is critical.
    vec.append(int(np.argmin(need_levels)) / 5.0)
    vec.append(1.0 if any(level < 0.15 for level in need_levels) else 0.0)

    # Zone: 0 = home-type, 1 = shop/venue, 2 = workplace, 3 = anything else.
    if current_loc.startswith(("house_", "apartment_", "apt_")):
        zone = 0
    elif current_loc in ("cafe", "grocery", "bar", "restaurant", "bakery", "cinema", "diner", "pharmacy"):
        zone = 1
    elif current_loc in ("office", "office_tower", "factory", "school", "hospital"):
        zone = 2
    else:
        zone = 3
    at_home = current_loc == home_loc
    vec.append(zone / 3.0)
    vec.append(1.0 if at_home else 0.0)
    vec.append(1.0 if current_loc == work_loc else 0.0)
    vec.append(min(num_people_here / 10.0, 1.0))

    # Location one-hot: slots 0-2 mirror zones 0-2, slot 3 is any other
    # non-street place, slot 4 is a street, slot 5 repeats the at-home flag.
    if zone < 3:
        slot = zone
    elif current_loc.startswith("street_"):
        slot = 4
    else:
        slot = 3
    location_one_hot = [0.0] * 6
    location_one_hot[slot] = 1.0
    if at_home:
        location_one_hot[5] = 1.0
    vec.extend(location_one_hot)

    period = _time_period(hour)
    vec.extend(1.0 if idx == period else 0.0 for idx in range(7))

    # Nine trailing slots are always zero.
    vec.extend([0.0] * 9)

    return np.array([vec], dtype=np.float32)
| |
|
| |
|
| | |
| |
|
| | def _extract_persona_from_system(system: str) -> dict: |
| | """Pull personality traits, age, home/work from the system prompt.""" |
| | info: dict = {} |
| | |
| | for trait in ("openness", "conscientiousness", "extraversion", "agreeableness", "neuroticism"): |
| | |
| | |
| | info[trait] = 5 |
| | |
| | age_m = re.search(r"(\d+)-year-old", system) |
| | if age_m: |
| | info["age"] = int(age_m.group(1)) |
| | else: |
| | info["age"] = 30 |
| | return info |
| |
|
| |
|
def _extract_state_from_user(user_message: str) -> dict:
    """Extract time, location, needs from the user prompt.

    Returns a dict with keys ``hour``, ``minute``, ``day``, ``location``,
    ``needs`` and ``mood``. Anything that cannot be found keeps its default:
    12:00 on day 1, an empty location, no needs, and mood 0.0 (mood is
    never parsed here).
    """
    state: dict = {"hour": 12, "minute": 0, "day": 1, "location": "", "needs": {}, "mood": 0.0}

    # Clock time: the first "H:MM" / "HH:MM" token in the prompt.
    time_m = re.search(r"(\d{1,2}):(\d{2})", user_message)
    if time_m:
        state["hour"] = int(time_m.group(1))
        state["minute"] = int(time_m.group(2))
    day_m = re.search(r"Day\s+(\d+)", user_message)
    if day_m:
        state["day"] = int(day_m.group(1))

    # Location: examine every "at <phrase>." / "at <phrase>," occurrence,
    # not only the first, so an unrelated earlier phrase (e.g. the "at"
    # inside "eat lunch,") cannot hide a later real location.
    for loc_m in re.finditer(r"at (\w[\w\s&']+?)[\.\,]", user_message):
        loc_name = loc_m.group(1).strip().lower()
        # Prefer ids spelled out in full within the phrase and take the
        # longest, so "office tower" resolves to office_tower, not office.
        spelled_out = [
            loc_id for loc_id in LOCATIONS if loc_id.replace("_", " ") in loc_name
        ]
        if spelled_out:
            state["location"] = max(spelled_out, key=len)
            break
        # Otherwise accept the first id that contains the phrase as a fragment.
        partial = next((loc_id for loc_id in LOCATIONS if loc_name in loc_id), None)
        if partial is not None:
            state["location"] = partial
            break

    # Needs: "<need> = <number>" or "<need>: <number>". The number pattern
    # accepts at most one decimal point, so a sentence-ending period
    # ("hunger: 0.5.") is not captured and float() cannot fail on it.
    for need in NEED_NAMES:
        nm = re.search(
            rf"{need}\s*[=:]\s*(\d+(?:\.\d+)?|\.\d+)", user_message, re.IGNORECASE
        )
        if nm:
            state["needs"][need] = float(nm.group(1))

    return state
| |
|
| |
|
| | |
| |
|
# HuggingFace Hub repository the ONNX model is downloaded from by default.
_DEFAULT_REPO = "RayMelius/soci-agent-nn"
# File name of the model, both inside the repository and in the local cache.
_MODEL_FILENAME = "soci_agent.onnx"
| |
|
| |
|
def _download_model(repo_id: str = _DEFAULT_REPO, cache_dir: str = "models") -> str:
    """Download the ONNX model from HuggingFace Hub if not cached.

    Args:
        repo_id: HuggingFace Hub repository that hosts the model file.
        cache_dir: Local directory for the model; created if missing.

    Returns:
        Path to the local model file.

    Raises:
        ImportError: If neither ``huggingface_hub`` nor ``httpx`` is installed.
        Exception: Download errors from the chosen transport are propagated.
    """
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    local_path = cache / _MODEL_FILENAME

    if local_path.exists():
        logger.info(f"NN model cached at {local_path}")
        return str(local_path)

    logger.info(f"Downloading NN model from {repo_id}...")

    # Guard only the import: an ImportError raised while the hub client is
    # actually downloading must surface, not silently switch transports.
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        hf_hub_download = None

    if hf_hub_download is not None:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename=_MODEL_FILENAME,
            local_dir=str(cache),
        )
        logger.info(f"NN model downloaded to {downloaded}")
        return downloaded

    # Fallback: plain HTTPS download when huggingface_hub is not installed.
    import httpx as _httpx
    url = f"https://huggingface.co/{repo_id}/resolve/main/{_MODEL_FILENAME}"
    logger.info(f"Downloading from {url}")
    resp = _httpx.get(url, follow_redirects=True, timeout=120.0)
    resp.raise_for_status()

    # Write to a sibling temp file and rename it into place, so an
    # interrupted write never leaves a truncated file that the cache check
    # above would later accept as a valid model.
    tmp_path = local_path.with_name(local_path.name + ".part")
    try:
        tmp_path.write_bytes(resp.content)
        tmp_path.replace(local_path)
    finally:
        if tmp_path.exists():
            tmp_path.unlink()
    logger.info(f"NN model saved to {local_path} ({len(resp.content):,} bytes)")
    return str(local_path)
| |
|
| |
|
| | |
| |
|
@dataclass
class NNUsage:
    """Call counter for the NN client; the reported cost is always $0.00."""

    calls: int = 0

    def record(self, *_args, **_kwargs) -> None:
        """Count one call; all positional and keyword arguments are ignored."""
        self.calls = self.calls + 1

    def summary(self) -> str:
        """Return a one-line usage report."""
        return f"calls: {self.calls}, $0.00"
| |
|
| |
|
| | |
| |
|
class NNClient:
    """ONNX-based neural network client — drop-in LLM replacement for Soci.

    Downloads soci-agent-nn from HuggingFace Hub on first use.
    Runs inference via ONNX Runtime on CPU (~1ms for 50 agents).
    Zero API calls, zero cost, works offline.
    """

    # Identification attributes, presumably read by the code that selects
    # and reports on LLM providers — verify against the caller.
    provider = "nn"
    default_model = "soci-agent-nn"
    llm_status = "active"

    def __init__(self, model_path: Optional[str] = None, repo_id: str = _DEFAULT_REPO):
        """Load the ONNX model, downloading it first when no path is given.

        Args:
            model_path: Path to an existing ONNX file; when ``None`` the
                model is fetched with ``_download_model(repo_id)``.
            repo_id: HuggingFace Hub repository used for downloads and reloads.

        Raises:
            ImportError: If ``onnxruntime`` is not installed.
        """
        if ort is None:
            raise ImportError(
                "onnxruntime is required for the NN provider. "
                "Install it with: pip install onnxruntime"
            )
        self._repo_id = repo_id
        if model_path is None:
            model_path = _download_model(repo_id)
        self._model_path = model_path
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
        )
        self.usage = NNUsage()
        self._last_error = ""
        logger.info(f"NN client loaded: {model_path}")

    def reload(self) -> str:
        """Re-download the ONNX model from HF Hub and reload the session.

        The current model file is moved aside rather than deleted up front.
        If the download or the session creation fails, the file is moved
        back and the running session is left untouched, so a failed reload
        never leaves the client without a model on disk.

        Returns a status message describing what happened.
        """
        local_path = Path(self._model_path)
        backup: Optional[Path] = None

        # Move the cached file out of the way so _download_model does not
        # short-circuit on its "already cached" check.
        if local_path.exists():
            old_size = local_path.stat().st_size
            backup = local_path.with_name(local_path.name + ".bak")
            local_path.replace(backup)
            logger.info(f"Set aside cached model ({old_size:,} bytes)")

        try:
            new_path = _download_model(self._repo_id)
            new_size = Path(new_path).stat().st_size
            new_session = ort.InferenceSession(
                new_path,
                providers=["CPUExecutionProvider"],
            )
        except Exception:
            # Restore the previous file (overwriting any partial or corrupt
            # download) and let the caller see the original error.
            if backup is not None:
                backup.replace(local_path)
                logger.warning("NN model reload failed; restored previous model file")
            raise

        if backup is not None and backup.exists():
            backup.unlink()

        # Swap in the new session only after everything succeeded.
        self.session = new_session
        self._model_path = new_path

        msg = f"NN model reloaded from {self._repo_id} ({new_size / 1024:.0f} KB)"
        logger.info(msg)
        return msg

    async def complete(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> str:
        """Return a JSON string action decision from the NN model.

        Thin wrapper around ``complete_json``; an empty result becomes "".
        """
        result = await self.complete_json(
            system=system,
            user_message=user_message,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return json.dumps(result) if result else ""

    async def complete_json(
        self,
        system: str,
        user_message: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
    ) -> dict:
        """Parse prompt context and run the NN to produce an action/conversation/plan.

        Detects the prompt type (action decision, conversation, plan, reflection)
        from the user_message content and routes to the appropriate handler.
        ``model`` and ``max_tokens`` are accepted for interface compatibility
        and are not used.
        """
        self.usage.record()

        # Route on marker phrases in the prompt; the order of checks matters.
        msg_lower = user_message.lower()
        if "plan your day" in msg_lower or "what will you do today" in msg_lower:
            return self._generate_plan(system, user_message)
        if '"action"' in msg_lower and '"target"' in msg_lower:
            return self._decide_action(system, user_message, temperature)
        if "how do you respond" in msg_lower or "you decide to start a conversation" in msg_lower:
            return self._generate_conversation(system, user_message)
        if "reflect on your recent" in msg_lower:
            return self._generate_reflection(system, user_message)
        if "how important is this" in msg_lower:
            return {"importance": random.randint(3, 7), "reaction": "Interesting."}
        # Unrecognised prompts are treated as action decisions.
        return self._decide_action(system, user_message, temperature)

    def _decide_action(self, system: str, user_message: str, temperature: float = 0.7) -> dict:
        """Run the ONNX model to select an action.

        The action is sampled from the temperature-scaled softmax of the
        first output, the target is the argmax of the second output, and
        the duration comes from the optional third output.
        """
        persona = _extract_persona_from_system(system)
        state = _extract_state_from_user(user_message)

        # Needs missing from the prompt default to the neutral midpoint.
        needs = state["needs"]
        for n in NEED_NAMES:
            needs.setdefault(n, 0.5)

        features = encode_features(
            personality=persona,
            age=persona.get("age", 30),
            hour=state["hour"],
            minute=state["minute"],
            day=state["day"],
            needs=needs,
            mood=state.get("mood", 0.0),
            # The parser always sets "location" (possibly to ""), so a
            # dict.get default would never apply; an empty location would
            # also equal the empty home/work ids below and spuriously set
            # the at-home / at-work features.
            current_loc=state.get("location") or "town_square",
            home_loc="",
            work_loc="",
            num_people_here=0,
        )

        outputs = self.session.run(None, {"features": features})
        action_logits = outputs[0][0]
        location_logits = outputs[1][0]
        # Take the first element explicitly: float() on a 1-element array
        # (ndim > 0) is deprecated in recent NumPy releases.
        if len(outputs) > 2:
            duration_pred = float(np.asarray(outputs[2]).reshape(-1)[0])
        else:
            duration_pred = 2.0

        # Temperature-scaled softmax; the temperature is floored at 0.1 and
        # the max logit is subtracted for numerical stability.
        logits = action_logits / max(temperature, 0.1)
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum()
        action_idx = int(np.random.choice(len(ACTION_TYPES), p=probs))
        action = ACTION_TYPES[action_idx]

        loc_idx = int(np.argmax(location_logits))
        target = LOCATIONS[loc_idx] if loc_idx < len(LOCATIONS) else ""

        # Clamp to 1..8, and fall back to the typical duration when the
        # prediction is more than 3 away from it.
        duration = max(1, min(8, round(duration_pred)))
        if action in ACTION_DURATIONS and abs(duration - ACTION_DURATIONS[action]) > 3:
            duration = ACTION_DURATIONS[action]

        return {
            "action": action,
            "target": target,
            "detail": f"NN: {action} at {target}",
            "duration": duration,
            "reasoning": f"NN model (conf: {probs[action_idx]:.0%})",
        }

    def _generate_plan(self, system: str, user_message: str) -> dict:
        """Generate a simple daily plan based on persona and time."""
        persona = _extract_persona_from_system(system)
        state = _extract_state_from_user(user_message)

        plan = []
        # Breakfast only makes it into the plan early in the day.
        if state["hour"] <= 8:
            plan.append("Have breakfast")
        plan.append("Head to work")
        plan.append("Work through the morning")
        plan.append("Lunch break")
        plan.append("Afternoon work session")

        # Evening activities depend on the extraversion score.
        E = persona.get("extraversion", 5)
        if E >= 7:
            plan.append("Meet friends for dinner")
            plan.append("Go to the bar")
        elif E >= 4:
            plan.append("Dinner at a restaurant")
            plan.append("Relaxing walk in the park")
        else:
            plan.append("Quiet dinner at home")
            plan.append("Read a book")

        plan.append("Get some sleep")

        return {"plan": plan, "reasoning": "NN-generated daily plan"}

    def _generate_conversation(self, system: str, user_message: str) -> dict:
        """Generate a conversation turn.

        An opening turn returns a greeting and a topic; any other turn
        returns a canned reply with small random sentiment/trust deltas.
        """
        if "you decide to start a conversation" in user_message.lower():
            return {
                "message": random.choice(_GREETINGS),
                "inner_thought": "Let's see what they're up to.",
                "topic": random.choice(["catching up", "the weather", "what's new", "plans"]),
            }
        return {
            "message": random.choice(_REPLIES + _SMALLTALK),
            "inner_thought": "Interesting conversation.",
            "sentiment_delta": round(random.uniform(-0.02, 0.05), 3),
            "trust_delta": round(random.uniform(-0.01, 0.03), 3),
        }

    def _generate_reflection(self, system: str, user_message: str) -> dict:
        """Generate a reflection with mood shift.

        Both arguments are unused: two canned reflections are sampled and
        the mood shift is drawn at random.
        """
        reflections = [
            "Things have been going well lately.",
            "I should spend more time doing what I enjoy.",
            "The people around me make this place feel like home.",
        ]
        return {
            "reflections": random.sample(reflections, k=min(2, len(reflections))),
            "mood_shift": round(random.uniform(-0.1, 0.15), 2),
            "reasoning": "Reflecting on recent experiences.",
        }
| |
|