# Commit abaee9a (saneowl):
# Remove hardcoded NVIDIA API key - now requires env var
"""
models.py
─────────────────────────────────────────────────────────────────────────────
Handles LLM inference for the Survival Island AI agent.
Priority chain
--------------
1. NVIDIA API (GLM-4.7 with thinking mode) - DEFAULT
2. Local transformers pipeline (available on GPU Spaces; used only once get_pipeline() has loaded it)
3. HuggingFace Inference API (works on CPU Spaces, needs HF_TOKEN)
4. Rule-based fallback - always returns WANDER (no network / no token, or every path above failed)
Environment variables
---------------------
NVIDIA_API_KEY - NVIDIA API key (required for path 1, GLM-4.7 with thinking)
HF_TOKEN - HuggingFace API token (required for path 3; optional for path 2)
HF_MODEL - Model repo ID for paths 2 and 3 (default: "mistralai/Mistral-7B-Instruct-v0.2")
"""
from __future__ import annotations
import json
import logging
import os
import re
import sys
import threading
from typing import Optional
import requests
logger = logging.getLogger(__name__)

# ── Config ────────────────────────────────────────────────────────────────────
# Credentials are read from the environment only; an unset variable yields "".
NVIDIA_API_KEY: str = os.environ.get("NVIDIA_API_KEY", "")
NVIDIA_BASE_URL: str = "https://integrate.api.nvidia.com/v1"
NVIDIA_MODEL: str = "z-ai/glm4.7"

HF_TOKEN: str = os.environ.get("HF_TOKEN") or ""
HF_MODEL: str = os.environ.get("HF_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")

# Every action name the parser accepts; anything else is replaced by WANDER.
VALID_ACTIONS: set[str] = {
    "FORAGE",
    "HUNT",
    "FISH",
    "GET_WATER",
    "SEEK_SHELTER",
    "BUILD_CAMP",
    "UPGRADE_CAMP",
    "CRAFT_SPEAR",
    "CRAFT_BOW",
    "CRAFT_ROD",
    "CRAFT_BOAT",
    "EVACUATE",
    "FIGHT",
    "FLEE",
    "WANDER",
}

# ── Singleton ─────────────────────────────────────────────────────────────────
# Module-level state for the optional local transformers pipeline.
_pipeline = None  # transformers Pipeline object, or None until loaded
_use_local: bool = False  # flipped to True once the local pipeline loads
_lock = threading.Lock()  # held by get_pipeline() around the lazy load
# ── NVIDIA API (GLM-4.7 with thinking mode) ──────────────────────────────
def infer_nvidia(prompt: str) -> dict:
    """Call the NVIDIA API with the GLM-4.7 model (thinking mode enabled).

    Args:
        prompt: Prompt text, sent as a single user message.

    Returns:
        The ``{"action", "thought"}`` dict produced by ``_parse_action``.

    Raises:
        RuntimeError: If ``NVIDIA_API_KEY`` is empty or the ``openai``
            package is not installed.
        ValueError: If the response has no text content, or the content
            cannot be parsed into an action.
    """
    # Fail fast, as infer_api does for HF_TOKEN: without a key the request
    # cannot authenticate, so skip the client setup and network round trip.
    if not NVIDIA_API_KEY:
        raise RuntimeError("NVIDIA_API_KEY not set - cannot call NVIDIA API")

    # Imported lazily so the module still loads when openai is absent.
    try:
        from openai import OpenAI
    except ImportError as exc:
        raise RuntimeError(
            "openai package not installed. Run: pip install openai"
        ) from exc

    client = OpenAI(
        base_url=NVIDIA_BASE_URL,
        api_key=NVIDIA_API_KEY,
    )
    logger.info("[models] Calling NVIDIA API (GLM-4.7) with thinking mode...")
    response = client.chat.completions.create(
        model=NVIDIA_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        top_p=1,
        max_tokens=1024,
        extra_body={"chat_template_kwargs": {"enable_thinking": True, "clear_thinking": False}},
    )

    message = response.choices[0].message
    content = message.content
    # reasoning_content is not guaranteed to exist on the message object.
    reasoning = getattr(message, "reasoning_content", None)
    if reasoning:
        logger.info("[models] Thinking: %s...", reasoning[:200])

    # content can be None/empty; slicing None would raise an opaque TypeError.
    if not content:
        raise ValueError("NVIDIA API returned no message content")

    logger.info("[models] Response: %s...", content[:200])
    return _parse_action(content)
# ── Local pipeline (optional - GPU Spaces) ────────────────────────────────────
def _try_load_local() -> bool:
    """Try to build the local text-generation pipeline with transformers + torch.

    On success stores the pipeline in ``_pipeline``, sets ``_use_local`` and
    returns True. Any exception (missing packages, load failure, ...) is
    logged as a warning and turned into a False return.
    """
    global _pipeline, _use_local
    try:
        # Heavy optional dependencies: imported only when a load is attempted.
        import torch
        from transformers import pipeline as hf_pipeline

        logger.info(f"[models] Loading {HF_MODEL} locally ...")

        # First CUDA device when available, otherwise -1.
        device = -1
        if torch.cuda.is_available():
            device = 0

        # Half precision when a CUDA device was selected, full precision otherwise.
        if device >= 0:
            dtype = torch.float16
        else:
            dtype = torch.float32

        load_options = {
            "model": HF_MODEL,
            "token": HF_TOKEN or None,
            "device": device,
            "torch_dtype": dtype,
            "max_new_tokens": 80,
        }
        _pipeline = hf_pipeline("text-generation", **load_options)
        _use_local = True
        logger.info(f"[models] Local pipeline ready (device={device})")
    except Exception as exc:
        logger.warning(f"[models] Local pipeline skipped: {exc}")
        return False
    return True
def get_pipeline():
    """Return the shared pipeline, attempting a local load while none is stored.

    The load attempt runs under ``_lock``. Returns None if the pipeline is
    still unavailable after the attempt.
    """
    with _lock:
        already_loaded = _pipeline is not None
        if not already_loaded:
            _try_load_local()
        return _pipeline
# ── Output parser ─────────────────────────────────────────────────────────────
def _parse_action(raw: str) -> dict:
    """Extract {"action": ..., "thought": ...} from model output.

    Markdown code fences are stripped, then the text is scanned for the first
    non-empty JSON object. Decoding uses the JSON parser itself rather than a
    brace-matching regex, so nested objects and ``}`` characters inside string
    values are handled, and brace-delimited text that is not valid JSON is
    skipped instead of aborting the parse.

    Args:
        raw: Text produced by the model.

    Returns:
        A dict with "action" (a member of VALID_ACTIONS, WANDER when missing
        or unknown) and "thought" (a string, "Processing..." when missing).

    Raises:
        ValueError: If ``raw`` is not a string or holds no JSON object.
    """
    if not isinstance(raw, str):
        raise ValueError(f"No JSON object in model output: {raw!r}")

    cleaned = re.sub(r"```(?:json)?", "", raw, flags=re.IGNORECASE)

    decoder = json.JSONDecoder()
    obj = None
    start = cleaned.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(cleaned, start)
        except json.JSONDecodeError:
            candidate = None
        # Empty objects carry no information; keep scanning past them.
        if isinstance(candidate, dict) and candidate:
            obj = candidate
            break
        start = cleaned.find("{", start + 1)

    if obj is None:
        raise ValueError(f"No JSON object in model output: {raw!r}")

    # Tolerate case and stray whitespace in the action name.
    action = str(obj.get("action", "WANDER")).strip().upper()
    if action not in VALID_ACTIONS:
        logger.warning("[models] Unknown action '%s', defaulting to WANDER", action)
        action = "WANDER"
    return {
        "action": action,
        "thought": str(obj.get("thought", "Processing...")),
    }
# ── Inference paths ───────────────────────────────────────────────────────────
def infer_local(prompt: str) -> dict:
    """Generate with the local transformers pipeline and parse the result.

    Raises RuntimeError when ``get_pipeline()`` yields no pipeline.
    """
    generator = get_pipeline()
    if generator is None:
        raise RuntimeError("Local pipeline not available")

    # return_full_text=False: only the newly generated text is parsed.
    sampling_options = {
        "max_new_tokens": 80,
        "temperature": 0.7,
        "do_sample": True,
        "return_full_text": False,
    }
    generations = generator(prompt, **sampling_options)
    first_generation = generations[0]
    return _parse_action(first_generation["generated_text"])
def infer_api(prompt: str) -> dict:
    """Query the HuggingFace Inference API for ``HF_MODEL`` and parse the reply.

    Raises RuntimeError when ``HF_TOKEN`` is empty; HTTP error statuses
    propagate from ``raise_for_status()``.
    """
    if not HF_TOKEN:
        raise RuntimeError("HF_TOKEN not set - cannot call Inference API")

    endpoint = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
    request_headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    generation_params = {
        "max_new_tokens": 80,
        "temperature": 0.7,
        "return_full_text": False,
        "stop": ["\n\n", "</s>", "[INST]"],
    }
    payload = {"inputs": prompt, "parameters": generation_params}

    response = requests.post(
        endpoint,
        headers=request_headers,
        json=payload,
        timeout=30,
    )
    response.raise_for_status()
    body = response.json()

    # A list body is read from its first element; any other body is
    # queried with .get(), defaulting to an empty string.
    if isinstance(body, list):
        generated_text = body[0]["generated_text"]
    else:
        generated_text = body.get("generated_text", "")
    return _parse_action(generated_text)
def run_inference(prompt: str) -> dict:
    """Main entry point used by server/app.py.

    Walks the backends in priority order - NVIDIA API, local pipeline (only
    when ``_use_local`` is already true), HF Inference API (only when
    ``HF_TOKEN`` is set) - and returns the first successful result. Every
    backend exception is logged and swallowed, so this always returns a valid
    {"action", "thought"} dict; when all backends fail it returns WANDER.
    """
    # (start message, label for the failure log, enabled check, backend).
    # The checks are callables so each flag is read just before its own turn.
    backends = (
        ("[models] Trying NVIDIA API inference...", "NVIDIA API", lambda: True, infer_nvidia),
        ("[models] Trying local inference...", "Local", lambda: _use_local, infer_local),
        ("[models] Trying API inference...", "API", lambda: HF_TOKEN, infer_api),
    )

    for start_message, label, is_enabled, backend in backends:
        if not is_enabled():
            continue
        try:
            logger.info(start_message)
            return backend(prompt)
        except Exception as exc:
            logger.warning(f"[models] {label} inference failed: {type(exc).__name__}: {exc}")

    logger.error("[models] All inference paths failed - returning WANDER")
    return {"action": "WANDER", "thought": "Inference offline. Wandering."}