Spaces:
Sleeping
Sleeping
adityaverma977 commited on
Commit ·
2a49cf3
1
Parent(s): 7b69e72
Update HF router models and inference flow
Browse files- backend/app/groq_client.py +195 -208
- backend/app/hf_spaces.py +87 -30
- backend/app/main.py +2 -1
backend/app/groq_client.py
CHANGED
|
@@ -1,274 +1,261 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
import os
|
| 3 |
import random
|
| 4 |
-
import
|
|
|
|
| 5 |
import httpx
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
_HF_API_BASE = "https://api-inference.huggingface.co/models"
|
| 13 |
|
| 14 |
MAX_AGENT_SPEED = 80
|
| 15 |
|
| 16 |
-
|
| 17 |
-
print(f"[GROQ_CLIENT_INIT] HF_API_TOKEN present: {_HF_API_TOKEN is not None and len(_HF_API_TOKEN) > 0}")
|
| 18 |
if not _HF_API_TOKEN:
|
| 19 |
print("[GROQ_CLIENT_INIT] WARNING: No HF API token found! Set HF_API_TOKEN or HUGGINGFACE_API_TOKEN env var.")
|
| 20 |
|
| 21 |
-
# Curated HF model ids verified to work with HF Inference API
|
| 22 |
-
HF_MODELS = [
|
| 23 |
-
"mistralai/Mistral-7B-Instruct-v0.2",
|
| 24 |
-
"mistralai/Mistral-7B-Instruct-v0.1",
|
| 25 |
-
"NousResearch/Nous-Hermes-2-7b",
|
| 26 |
-
"HuggingFaceH4/zephyr-7b-beta",
|
| 27 |
-
"tiiuae/falcon-7b-instruct",
|
| 28 |
-
"meta-llama/Llama-2-7b-chat-hf",
|
| 29 |
-
"meta-llama/Llama-2-13b-chat-hf",
|
| 30 |
-
"stabilityai/stablelm-tuned-alpha-3b",
|
| 31 |
-
"WizardLM/WizardLM-7B-V1.0",
|
| 32 |
-
]
|
| 33 |
-
|
| 34 |
|
| 35 |
def is_ready():
|
| 36 |
-
|
| 37 |
-
return _HF_API_TOKEN is not None
|
| 38 |
|
| 39 |
|
| 40 |
-
def
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
def _generate_chat_message(action: str, agent_name: str, fire_distance: float, has_water: bool) -> str:
|
| 45 |
-
"""Generate a contextual chat message based on action and state."""
|
| 46 |
-
import random
|
| 47 |
-
|
| 48 |
action_messages = {
|
| 49 |
"search_water": [
|
| 50 |
f"{agent_name} is hunting for water...",
|
| 51 |
-
f"
|
| 52 |
-
"
|
| 53 |
-
|
| 54 |
-
"Locating nearest well...",
|
| 55 |
-
"Water mission initiated!",
|
| 56 |
],
|
| 57 |
"collect_water": [
|
| 58 |
-
f"{agent_name}
|
| 59 |
-
"
|
| 60 |
-
|
| 61 |
-
"
|
| 62 |
-
"Tank is full, let's go!",
|
| 63 |
-
f"{agent_name} loading water supply...",
|
| 64 |
],
|
| 65 |
"extinguish_fire": [
|
| 66 |
-
f"{agent_name}
|
| 67 |
-
"
|
| 68 |
-
|
| 69 |
-
"
|
| 70 |
-
"Taking the fight to the fire!",
|
| 71 |
-
f"{agent_name} is fighting hard!",
|
| 72 |
],
|
| 73 |
"escape": [
|
| 74 |
-
f"{agent_name}
|
| 75 |
-
"
|
| 76 |
-
"
|
| 77 |
-
|
| 78 |
-
"Backing away from danger!",
|
| 79 |
-
"Moving to safer ground...",
|
| 80 |
],
|
| 81 |
"vote_for_leader": [
|
| 82 |
-
f"{agent_name}
|
| 83 |
-
"
|
| 84 |
-
"
|
| 85 |
-
|
| 86 |
-
"Let's coordinate and dominate!",
|
| 87 |
-
"Voting for strategic leadership...",
|
| 88 |
],
|
| 89 |
}
|
| 90 |
-
|
| 91 |
messages = action_messages.get(action, action_messages["escape"])
|
| 92 |
return random.choice(messages)
|
| 93 |
|
| 94 |
|
| 95 |
def _build_fire_state_summary(agent, fire, all_agents) -> str:
|
| 96 |
-
"""Build a state summary for the fire scenario."""
|
| 97 |
standings = []
|
| 98 |
-
for
|
| 99 |
-
if not
|
| 100 |
continue
|
| 101 |
-
|
| 102 |
standings.append({
|
| 103 |
-
"name":
|
| 104 |
-
"
|
| 105 |
-
"
|
| 106 |
-
"x": a.x,
|
| 107 |
-
"y": a.y,
|
| 108 |
-
"has_water": a.water_collected,
|
| 109 |
-
"mode": a.mode,
|
| 110 |
})
|
| 111 |
|
| 112 |
-
standings.sort(key=lambda
|
| 113 |
-
|
| 114 |
lines = ["Current standings:"]
|
| 115 |
-
for
|
| 116 |
-
|
| 117 |
-
lines.append(f"
|
| 118 |
-
|
| 119 |
return "\n".join(lines)
|
| 120 |
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
async def generate_fire_decision(agent, fire, water_sources, other_agents, bounds, recent_radio=None) -> dict:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
"""
|
| 127 |
-
if not is_ready(): print(f"[INFERENCE_FAIL] {agent.model_name}: HF token not ready, using fallback") return _fallback_escape(agent, fire)
|
| 128 |
|
| 129 |
dist_to_fire = math.dist((agent.x, agent.y), (fire.x, fire.y))
|
| 130 |
-
nearest_water = min(water_sources, key=lambda
|
| 131 |
dist_to_water = math.dist((agent.x, agent.y), (nearest_water.x, nearest_water.y)) if nearest_water else None
|
| 132 |
-
|
| 133 |
-
living_agents = [
|
| 134 |
state_summary = _build_fire_state_summary(agent, fire, [agent] + living_agents)
|
| 135 |
radio_summary = "\n".join(recent_radio or []) if recent_radio else "(no recent chat yet)"
|
| 136 |
-
|
| 137 |
-
coalition_leader = next((a.model_name for a in other_agents if a.is_leader), None)
|
| 138 |
dist_to_water_display = f"{dist_to_water:.0f}px" if dist_to_water is not None else "unknown"
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
- A wildfire is spreading
|
| 144 |
-
- Water
|
| 145 |
-
-
|
| 146 |
-
-
|
| 147 |
-
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
- If
|
| 160 |
-
-
|
| 161 |
-
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
-
|
| 167 |
-
-
|
| 168 |
-
-
|
| 169 |
-
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
Fire radius: {fire.radius:.0f}px
|
| 176 |
-
Fire intensity: {fire.intensity:.0f}%
|
| 177 |
-
Carrying water: {agent.water_collected}
|
| 178 |
-
Mode: {agent.mode} ({'joined a coalition' if agent.mode == 'coalition' else 'acting alone'})
|
| 179 |
-
Nearest water distance: {dist_to_water_display}
|
| 180 |
-
Coalition leader: {coalition_leader or 'none'}
|
| 181 |
-
|
| 182 |
-
RECENT RADIO CHAT:
|
| 183 |
{radio_summary}
|
| 184 |
|
| 185 |
{state_summary}
|
| 186 |
|
| 187 |
-
|
| 188 |
-
{{"action":
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
if isinstance(data, list) and len(data) > 0:
|
| 214 |
-
text = data[0].get("generated_text", "")
|
| 215 |
-
else:
|
| 216 |
-
text = data.get("generated_text", "")
|
| 217 |
-
|
| 218 |
-
text = text[len(system_prompt):].strip() if text.startswith(system_prompt) else text
|
| 219 |
-
|
| 220 |
-
# DEBUG: log raw response for inspection
|
| 221 |
-
print(f"[HF_INFERENCE] {agent.model_name}: raw response (first 300 chars): {text[:300]}")
|
| 222 |
-
|
| 223 |
-
try:
|
| 224 |
-
json_start = text.find('{')
|
| 225 |
-
json_end = text.rfind('}') + 1
|
| 226 |
-
if json_start >= 0 and json_end > json_start:
|
| 227 |
-
json_str = text[json_start:json_end]
|
| 228 |
-
decision = json.loads(json_str)
|
| 229 |
-
print(f"[HF_INFERENCE] {agent.model_name}: decision parsed: action={decision.get('action')}, message={decision.get('message')}")
|
| 230 |
-
else:
|
| 231 |
-
print(f"[HF_INFERENCE] {agent.model_name}: no JSON found in response")
|
| 232 |
-
decision = {}
|
| 233 |
-
except json.JSONDecodeError as je:
|
| 234 |
-
print(f"[HF_INFERENCE] {agent.model_name}: JSON parse error: {je}")
|
| 235 |
-
decision = {}
|
| 236 |
-
|
| 237 |
-
action = decision.get("action", "escape")
|
| 238 |
-
if action not in ["search_water", "collect_water", "extinguish_fire", "escape", "vote_for_leader"]:
|
| 239 |
-
action = "escape"
|
| 240 |
-
|
| 241 |
-
# If no message extracted, generate one contextually
|
| 242 |
-
message = decision.get("message", "").strip()
|
| 243 |
-
if not message:
|
| 244 |
-
message = _generate_chat_message(action, agent.model_name, dist_to_fire, agent.water_collected)
|
| 245 |
-
print(f"[HF_INFERENCE] {agent.model_name}: generated message: {message}")
|
| 246 |
-
|
| 247 |
-
if dist_to_water is not None and dist_to_water <= 60 and not agent.water_collected:
|
| 248 |
-
action = "collect_water"
|
| 249 |
-
elif agent.water_collected and dist_to_fire <= 350:
|
| 250 |
-
action = "extinguish_fire"
|
| 251 |
-
|
| 252 |
-
return {
|
| 253 |
-
"action": action,
|
| 254 |
-
"vote_for": decision.get("vote_for"),
|
| 255 |
-
"message": message,
|
| 256 |
-
"reasoning": decision.get("reasoning", "Survival and teamwork.")
|
| 257 |
-
}
|
| 258 |
-
except Exception as e:
|
| 259 |
-
print(f"[HF_INFERENCE_ERROR] {agent.model_name}: {type(e).__name__}: {e}")
|
| 260 |
-
return _fallback_escape(agent, fire)
|
| 261 |
|
| 262 |
|
| 263 |
def _fallback_escape(agent, fire) -> dict:
|
| 264 |
-
"""Fallback escape behavior."""
|
| 265 |
-
dx = agent.x - fire.x
|
| 266 |
-
dy = agent.y - fire.y
|
| 267 |
-
dist = math.sqrt(dx**2 + dy**2) or 1
|
| 268 |
return {
|
| 269 |
"message": "Running to safety!",
|
| 270 |
"action": "escape",
|
| 271 |
"vote_for": None,
|
| 272 |
-
"reasoning": "Fallback: survive."
|
| 273 |
}
|
| 274 |
-
|
|
|
|
| 1 |
import json
|
| 2 |
+
import math
|
| 3 |
import os
|
| 4 |
import random
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
import httpx
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
|
| 10 |
+
from . import hf_spaces
|
| 11 |
+
|
| 12 |
+
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
|
| 13 |
|
| 14 |
+
_HF_API_TOKEN = (os.environ.get("HF_API_TOKEN") or os.environ.get("HUGGINGFACE_API_TOKEN") or "").strip()
|
| 15 |
+
_HF_CHAT_URL = "https://router.huggingface.co/v1/chat/completions"
|
|
|
|
| 16 |
|
| 17 |
MAX_AGENT_SPEED = 80
|
| 18 |
|
| 19 |
+
print(f"[GROQ_CLIENT_INIT] HF_API_TOKEN present: {bool(_HF_API_TOKEN)}")
|
|
|
|
| 20 |
if not _HF_API_TOKEN:
|
| 21 |
print("[GROQ_CLIENT_INIT] WARNING: No HF API token found! Set HF_API_TOKEN or HUGGINGFACE_API_TOKEN env var.")
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def is_ready():
|
| 25 |
+
return bool(_HF_API_TOKEN)
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
+
def _headers() -> dict[str, str]:
|
| 29 |
+
if not _HF_API_TOKEN:
|
| 30 |
+
return {}
|
| 31 |
+
return {
|
| 32 |
+
"Authorization": f"Bearer {_HF_API_TOKEN}",
|
| 33 |
+
"Content-Type": "application/json",
|
| 34 |
+
}
|
| 35 |
|
| 36 |
|
| 37 |
def _generate_chat_message(action: str, agent_name: str, fire_distance: float, has_water: bool) -> str:
|
|
|
|
|
|
|
|
|
|
| 38 |
action_messages = {
|
| 39 |
"search_water": [
|
| 40 |
f"{agent_name} is hunting for water...",
|
| 41 |
+
f"{agent_name} is tracking the nearest well.",
|
| 42 |
+
"Need water before this gets worse.",
|
| 43 |
+
"Scanning for the fastest water route.",
|
|
|
|
|
|
|
| 44 |
],
|
| 45 |
"collect_water": [
|
| 46 |
+
f"{agent_name} is filling up now.",
|
| 47 |
+
"Got the well, taking water.",
|
| 48 |
+
"Water secured, moving out.",
|
| 49 |
+
"That should be enough to fight back.",
|
|
|
|
|
|
|
| 50 |
],
|
| 51 |
"extinguish_fire": [
|
| 52 |
+
f"{agent_name} is pushing the fire line.",
|
| 53 |
+
"Closing in with water.",
|
| 54 |
+
"Time to hit the flames.",
|
| 55 |
+
"Pressure on the fire now.",
|
|
|
|
|
|
|
| 56 |
],
|
| 57 |
"escape": [
|
| 58 |
+
f"{agent_name} is backing out.",
|
| 59 |
+
"Too hot here, pulling away.",
|
| 60 |
+
"Need space before the fire closes in.",
|
| 61 |
+
"Resetting position and staying alive.",
|
|
|
|
|
|
|
| 62 |
],
|
| 63 |
"vote_for_leader": [
|
| 64 |
+
f"{agent_name} wants a leader in place.",
|
| 65 |
+
"Coordination first, then pressure.",
|
| 66 |
+
"Picking a lead so we stop wasting ticks.",
|
| 67 |
+
"We need one caller right now.",
|
|
|
|
|
|
|
| 68 |
],
|
| 69 |
}
|
|
|
|
| 70 |
messages = action_messages.get(action, action_messages["escape"])
|
| 71 |
return random.choice(messages)
|
| 72 |
|
| 73 |
|
| 74 |
def _build_fire_state_summary(agent, fire, all_agents) -> str:
|
|
|
|
| 75 |
standings = []
|
| 76 |
+
for other in all_agents:
|
| 77 |
+
if not other.alive:
|
| 78 |
continue
|
| 79 |
+
distance = math.dist((other.x, other.y), (fire.x, fire.y))
|
| 80 |
standings.append({
|
| 81 |
+
"name": other.display_name,
|
| 82 |
+
"distance_from_fire": distance,
|
| 83 |
+
"has_water": other.water_collected,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
})
|
| 85 |
|
| 86 |
+
standings.sort(key=lambda item: item["distance_from_fire"])
|
|
|
|
| 87 |
lines = ["Current standings:"]
|
| 88 |
+
for index, item in enumerate(standings, 1):
|
| 89 |
+
suffix = " (carrying water)" if item["has_water"] else ""
|
| 90 |
+
lines.append(f"#{index} {item['name']}: {item['distance_from_fire']:.0f}px from fire{suffix}")
|
|
|
|
| 91 |
return "\n".join(lines)
|
| 92 |
|
| 93 |
|
| 94 |
+
def _extract_message_content(payload) -> str:
|
| 95 |
+
choices = payload.get("choices") or []
|
| 96 |
+
if not choices or not isinstance(choices[0], dict):
|
| 97 |
+
return ""
|
| 98 |
+
message = choices[0].get("message") or {}
|
| 99 |
+
content = message.get("content")
|
| 100 |
+
if isinstance(content, str):
|
| 101 |
+
return content.strip()
|
| 102 |
+
if isinstance(content, list):
|
| 103 |
+
parts = []
|
| 104 |
+
for item in content:
|
| 105 |
+
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
|
| 106 |
+
parts.append(item["text"])
|
| 107 |
+
return "".join(parts).strip()
|
| 108 |
+
return ""
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _extract_json_object(text: str) -> dict:
|
| 112 |
+
if not text:
|
| 113 |
+
return {}
|
| 114 |
+
|
| 115 |
+
cleaned = text.strip()
|
| 116 |
+
if cleaned.startswith("```"):
|
| 117 |
+
cleaned = cleaned.replace("```json", "").replace("```", "").strip()
|
| 118 |
+
|
| 119 |
+
start = cleaned.find("{")
|
| 120 |
+
end = cleaned.rfind("}") + 1
|
| 121 |
+
if start < 0 or end <= start:
|
| 122 |
+
return {}
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
candidate = cleaned[start:end]
|
| 126 |
+
parsed = json.loads(candidate)
|
| 127 |
+
except json.JSONDecodeError:
|
| 128 |
+
return {}
|
| 129 |
+
|
| 130 |
+
return parsed if isinstance(parsed, dict) else {}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _normalize_decision(decision: dict, agent_name: str, dist_to_fire: float, has_water: bool) -> dict:
|
| 134 |
+
action = decision.get("action", "escape")
|
| 135 |
+
if action not in {"search_water", "collect_water", "extinguish_fire", "escape", "vote_for_leader"}:
|
| 136 |
+
action = "escape"
|
| 137 |
+
|
| 138 |
+
message = " ".join(str(decision.get("message", "")).strip().split())
|
| 139 |
+
if not message:
|
| 140 |
+
message = _generate_chat_message(action, agent_name, dist_to_fire, has_water)
|
| 141 |
+
|
| 142 |
+
vote_for = decision.get("vote_for")
|
| 143 |
+
if vote_for is not None and not isinstance(vote_for, str):
|
| 144 |
+
vote_for = None
|
| 145 |
+
|
| 146 |
+
reasoning = " ".join(str(decision.get("reasoning", "")).strip().split())
|
| 147 |
+
if not reasoning:
|
| 148 |
+
reasoning = "Survival and teamwork."
|
| 149 |
+
|
| 150 |
+
return {
|
| 151 |
+
"action": action,
|
| 152 |
+
"vote_for": vote_for,
|
| 153 |
+
"message": message[:220],
|
| 154 |
+
"reasoning": reasoning[:220],
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
async def _request_model_response(target_model: str, prompt: str) -> str:
|
| 159 |
+
payload = {
|
| 160 |
+
"model": target_model,
|
| 161 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 162 |
+
"max_tokens": 220,
|
| 163 |
+
"temperature": 0.4,
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
async with httpx.AsyncClient(timeout=20.0) as client:
|
| 167 |
+
response = await client.post(_HF_CHAT_URL, headers=_headers(), json=payload)
|
| 168 |
+
response.raise_for_status()
|
| 169 |
+
data = response.json()
|
| 170 |
+
return _extract_message_content(data)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
async def generate_fire_decision(agent, fire, water_sources, other_agents, bounds, recent_radio=None) -> dict:
|
| 174 |
+
if not is_ready():
|
| 175 |
+
print(f"[INFERENCE_FAIL] {agent.model_name}: HF token not ready, using fallback")
|
| 176 |
+
return _fallback_escape(agent, fire)
|
|
|
|
|
|
|
| 177 |
|
| 178 |
dist_to_fire = math.dist((agent.x, agent.y), (fire.x, fire.y))
|
| 179 |
+
nearest_water = min(water_sources, key=lambda water: math.dist((agent.x, agent.y), (water.x, water.y))) if water_sources else None
|
| 180 |
dist_to_water = math.dist((agent.x, agent.y), (nearest_water.x, nearest_water.y)) if nearest_water else None
|
| 181 |
+
|
| 182 |
+
living_agents = [other for other in other_agents if other.alive and other.model_name != agent.model_name]
|
| 183 |
state_summary = _build_fire_state_summary(agent, fire, [agent] + living_agents)
|
| 184 |
radio_summary = "\n".join(recent_radio or []) if recent_radio else "(no recent chat yet)"
|
| 185 |
+
coalition_leader = next((other.model_name for other in other_agents if other.is_leader), None)
|
|
|
|
| 186 |
dist_to_water_display = f"{dist_to_water:.0f}px" if dist_to_water is not None else "unknown"
|
| 187 |
+
|
| 188 |
+
prompt = f"""You are {agent.model_name} in a wildfire survival simulation.
|
| 189 |
+
|
| 190 |
+
Scenario:
|
| 191 |
+
- A wildfire is spreading across the map
|
| 192 |
+
- Water wells are scattered around the area
|
| 193 |
+
- Agents can coordinate as a coalition and may vote for a leader
|
| 194 |
+
- Winning means getting water and using it to extinguish the fire
|
| 195 |
+
- Dying in the fire means losing
|
| 196 |
+
|
| 197 |
+
Allowed actions:
|
| 198 |
+
- search_water
|
| 199 |
+
- collect_water
|
| 200 |
+
- extinguish_fire
|
| 201 |
+
- escape
|
| 202 |
+
- vote_for_leader
|
| 203 |
+
|
| 204 |
+
Rules:
|
| 205 |
+
- If the fire is too close, prioritize survival
|
| 206 |
+
- If you already have water, move to the fire edge and fight it
|
| 207 |
+
- If you are at a well, collect water immediately
|
| 208 |
+
- Keep the message short, natural, and mission-focused
|
| 209 |
+
- Respond with only valid JSON on one line
|
| 210 |
+
|
| 211 |
+
Current state:
|
| 212 |
+
- Position: ({agent.x}, {agent.y})
|
| 213 |
+
- Fire position: ({fire.x}, {fire.y})
|
| 214 |
+
- Distance from fire: {dist_to_fire:.0f}px
|
| 215 |
+
- Fire radius: {fire.radius:.0f}px
|
| 216 |
+
- Fire intensity: {fire.intensity:.0f}%
|
| 217 |
+
- Carrying water: {agent.water_collected}
|
| 218 |
+
- Mode: {agent.mode}
|
| 219 |
+
- Nearest water distance: {dist_to_water_display}
|
| 220 |
+
- Coalition leader: {coalition_leader or 'none'}
|
| 221 |
+
|
| 222 |
+
Recent radio:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
{radio_summary}
|
| 224 |
|
| 225 |
{state_summary}
|
| 226 |
|
| 227 |
+
Return exactly:
|
| 228 |
+
{{"action":"search_water|collect_water|extinguish_fire|escape|vote_for_leader","vote_for":null,"message":"short sentence","reasoning":"short sentence"}}"""
|
| 229 |
+
|
| 230 |
+
requested_model = agent.model_name if hf_spaces.is_supported_model(agent.model_name) else hf_spaces.get_default_model_id()
|
| 231 |
+
fallback_model = hf_spaces.get_default_model_id()
|
| 232 |
+
models_to_try = [requested_model]
|
| 233 |
+
if fallback_model not in models_to_try:
|
| 234 |
+
models_to_try.append(fallback_model)
|
| 235 |
+
|
| 236 |
+
for target_model in models_to_try:
|
| 237 |
+
try:
|
| 238 |
+
print(f"[HF_INFERENCE] {agent.model_name} -> calling {target_model}")
|
| 239 |
+
raw_text = await _request_model_response(target_model, prompt)
|
| 240 |
+
print(f"[HF_INFERENCE] {agent.model_name}: raw response (first 300 chars): {raw_text[:300]}")
|
| 241 |
+
decision = _extract_json_object(raw_text)
|
| 242 |
+
if decision:
|
| 243 |
+
normalized = _normalize_decision(decision, agent.model_name, dist_to_fire, agent.water_collected)
|
| 244 |
+
if dist_to_water is not None and dist_to_water <= 60 and not agent.water_collected:
|
| 245 |
+
normalized["action"] = "collect_water"
|
| 246 |
+
elif agent.water_collected and dist_to_fire <= 350:
|
| 247 |
+
normalized["action"] = "extinguish_fire"
|
| 248 |
+
return normalized
|
| 249 |
+
except Exception as exc:
|
| 250 |
+
print(f"[HF_INFERENCE_ERROR] {agent.model_name} via {target_model}: {type(exc).__name__}: {exc}")
|
| 251 |
+
|
| 252 |
+
return _fallback_escape(agent, fire)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
|
| 255 |
def _fallback_escape(agent, fire) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
return {
|
| 257 |
"message": "Running to safety!",
|
| 258 |
"action": "escape",
|
| 259 |
"vote_for": None,
|
| 260 |
+
"reasoning": "Fallback: survive.",
|
| 261 |
}
|
|
|
backend/app/hf_spaces.py
CHANGED
|
@@ -1,40 +1,97 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Model registry: return only Hugging Face models (no Groq entries).
|
| 3 |
-
This file lists a curated set of small, medium and large HF models
|
| 4 |
-
to populate the frontend model selector.
|
| 5 |
-
"""
|
| 6 |
import os
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
{"id": "
|
| 18 |
-
{"id": "
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
{"id": "
|
| 22 |
-
{"id": "
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
{"id": "
|
| 26 |
-
{"id": "
|
|
|
|
|
|
|
|
|
|
| 27 |
]
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
async def get_available_models() -> dict:
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def get_model_display_name(model_id: str) -> str:
|
| 36 |
-
for
|
| 37 |
-
if
|
| 38 |
-
return
|
| 39 |
return model_id.split("/")[-1].split("-")[0].capitalize()
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
|
| 5 |
+
import httpx
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
|
| 9 |
+
|
| 10 |
+
HF_API_TOKEN = (os.environ.get("HF_API_TOKEN") or os.environ.get("HUGGINGFACE_API_TOKEN") or "").strip()
|
| 11 |
+
ROUTER_MODELS_URL = "https://router.huggingface.co/v1/models"
|
| 12 |
+
|
| 13 |
+
PREFERRED_MODELS = [
|
| 14 |
+
{"id": "meta-llama/Llama-3.1-8B-Instruct", "name": "Llama 3.1 8B Instruct", "size": "medium", "description": "Fast general-purpose instruct model"},
|
| 15 |
+
{"id": "Qwen/Qwen2.5-7B-Instruct", "name": "Qwen 2.5 7B Instruct", "size": "medium", "description": "Reliable JSON-following instruction model"},
|
| 16 |
+
{"id": "meta-llama/Meta-Llama-3-8B-Instruct", "name": "Meta Llama 3 8B Instruct", "size": "medium", "description": "Strong general chat behavior"},
|
| 17 |
+
{"id": "google/gemma-3n-E4B-it", "name": "Gemma 3n E4B", "size": "small", "description": "Lightweight instruction-tuned Gemma model"},
|
| 18 |
+
{"id": "Sao10K/L3-8B-Stheno-v3.2", "name": "L3 8B Stheno v3.2", "size": "medium", "description": "Creative 8B chat model"},
|
| 19 |
+
{"id": "XiaomiMiMo/MiMo-V2-Flash", "name": "MiMo V2 Flash", "size": "medium", "description": "Fast flash-tier chat model"},
|
| 20 |
+
{"id": "google/gemma-4-26B-A4B-it", "name": "Gemma 4 26B A4B", "size": "large", "description": "Higher-capacity Gemma instruct model"},
|
| 21 |
+
{"id": "google/gemma-4-31B-it", "name": "Gemma 4 31B", "size": "large", "description": "Large Gemma chat model"},
|
| 22 |
+
{"id": "Qwen/Qwen3.5-35B-A3B", "name": "Qwen 3.5 35B A3B", "size": "large", "description": "Large Qwen instruction model"},
|
| 23 |
+
{"id": "google/gemma-3-27b-it", "name": "Gemma 3 27B", "size": "large", "description": "Large Gemma 3 instruct model"},
|
| 24 |
+
{"id": "moonshotai/Kimi-K2.5", "name": "Kimi K2.5", "size": "large", "description": "Large reasoning-oriented chat model"},
|
| 25 |
+
{"id": "Qwen/Qwen3-Coder-30B-A3B-Instruct", "name": "Qwen 3 Coder 30B", "size": "large", "description": "Coder-tuned instruction model"},
|
| 26 |
+
{"id": "meta-llama/Llama-3.3-70B-Instruct", "name": "Llama 3.3 70B Instruct", "size": "xl", "description": "Large instruction-following flagship model"},
|
| 27 |
]
|
| 28 |
|
| 29 |
+
_CACHE = {"expires_at": 0.0, "ids": None}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _headers() -> dict[str, str]:
|
| 33 |
+
if not HF_API_TOKEN:
|
| 34 |
+
return {}
|
| 35 |
+
return {"Authorization": f"Bearer {HF_API_TOKEN}"}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _extract_router_models(payload) -> list[dict]:
|
| 39 |
+
if isinstance(payload, list):
|
| 40 |
+
return [item for item in payload if isinstance(item, dict)]
|
| 41 |
+
if isinstance(payload, dict):
|
| 42 |
+
data = payload.get("data")
|
| 43 |
+
if isinstance(data, list):
|
| 44 |
+
return [item for item in data if isinstance(item, dict)]
|
| 45 |
+
return []
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def _fetch_router_model_ids() -> set[str] | None:
|
| 49 |
+
now = time.monotonic()
|
| 50 |
+
cached_ids = _CACHE["ids"]
|
| 51 |
+
if isinstance(cached_ids, set) and _CACHE["expires_at"] > now:
|
| 52 |
+
return cached_ids
|
| 53 |
+
|
| 54 |
+
if not HF_API_TOKEN:
|
| 55 |
+
return None
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
async with httpx.AsyncClient(timeout=15.0) as client:
|
| 59 |
+
response = await client.get(ROUTER_MODELS_URL, headers=_headers())
|
| 60 |
+
response.raise_for_status()
|
| 61 |
+
payload = response.json()
|
| 62 |
+
except Exception:
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
models = _extract_router_models(payload)
|
| 66 |
+
ids = {item["id"] for item in models if isinstance(item.get("id"), str)}
|
| 67 |
+
_CACHE["ids"] = ids
|
| 68 |
+
_CACHE["expires_at"] = now + 300
|
| 69 |
+
return ids
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_supported_model_ids() -> set[str]:
|
| 73 |
+
return {model["id"] for model in PREFERRED_MODELS}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def is_supported_model(model_id: str) -> bool:
|
| 77 |
+
return model_id in get_supported_model_ids()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_default_model_id() -> str:
|
| 81 |
+
return PREFERRED_MODELS[0]["id"]
|
| 82 |
+
|
| 83 |
|
| 84 |
async def get_available_models() -> dict:
|
| 85 |
+
live_ids = await _fetch_router_model_ids()
|
| 86 |
+
if live_ids:
|
| 87 |
+
models = [model for model in PREFERRED_MODELS if model["id"] in live_ids]
|
| 88 |
+
else:
|
| 89 |
+
models = list(PREFERRED_MODELS)
|
| 90 |
+
return {"models": models, "total": len(models)}
|
| 91 |
|
| 92 |
|
| 93 |
def get_model_display_name(model_id: str) -> str:
|
| 94 |
+
for model in PREFERRED_MODELS:
|
| 95 |
+
if model["id"] == model_id:
|
| 96 |
+
return model["name"]
|
| 97 |
return model_id.split("/")[-1].split("-")[0].capitalize()
|
|
|
backend/app/main.py
CHANGED
|
@@ -5,13 +5,14 @@ import random
|
|
| 5 |
import uuid
|
| 6 |
import os
|
| 7 |
import time
|
|
|
|
| 8 |
from typing import Optional
|
| 9 |
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
from pydantic import BaseModel, Field
|
| 12 |
from dotenv import load_dotenv
|
| 13 |
|
| 14 |
-
load_dotenv()
|
| 15 |
|
| 16 |
from .models import SimulationState, AgentModel, TickResponse, FireScenario, WaterSource
|
| 17 |
from .simulation import SimulationEngine, TICK_INTERVAL_SECONDS
|
|
|
|
| 5 |
import uuid
|
| 6 |
import os
|
| 7 |
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
from typing import Optional
|
| 10 |
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
from pydantic import BaseModel, Field
|
| 13 |
from dotenv import load_dotenv
|
| 14 |
|
| 15 |
+
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
|
| 16 |
|
| 17 |
from .models import SimulationState, AgentModel, TickResponse, FireScenario, WaterSource
|
| 18 |
from .simulation import SimulationEngine, TICK_INTERVAL_SECONDS
|