suraj140's picture
Clean repository root for HF Spaces
5c40041
"""
models.py
─────────────────────────────────────────────────────────────────────────────
Handles LLM inference for the Survival Island AI agent.
Priority chain
--------------
1. Local transformers pipeline (needs torch + transformers – uncomment them in requirements.txt; intended for GPU Spaces)
2. HuggingFace Inference API (works on CPU Spaces, needs HF_TOKEN)
3. Rule-based fallback (no network / no token)
Environment variables
---------------------
HF_TOKEN – HuggingFace API token (required for path 2)
HF_MODEL – Model repo ID (default "mistralai/Mistral-7B-Instruct-v0.2")
"""
from __future__ import annotations
import json
import logging
import os
import re
import threading
from typing import Optional
import requests
logger = logging.getLogger(__name__)
# ── Config ────────────────────────────────────────────────────────────────────
# Both values are read from the environment once, at import time.
# HF_TOKEN: an empty string means "no token" – the Inference API path is skipped.
HF_TOKEN: str = os.getenv("HF_TOKEN", "")
# HF_MODEL: repo ID used both for the local pipeline and in the Inference API URL.
HF_MODEL: str = os.getenv("HF_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
# Closed set of actions the agent understands; _parse_action() replaces any
# other model-proposed action with "WANDER".
VALID_ACTIONS: set[str] = {
    "FORAGE", "HUNT", "FISH", "GET_WATER", "SEEK_SHELTER",
    "BUILD_CAMP", "UPGRADE_CAMP", "CRAFT_SPEAR", "CRAFT_BOW",
    "CRAFT_ROD", "CRAFT_BOAT", "EVACUATE", "FIGHT", "FLEE", "WANDER",
}
# ── Singleton ─────────────────────────────────────────────────────────────────
# Lazily-initialised local-model state shared by the whole module.
_lock = threading.Lock()  # serialises the "load if missing" check in get_pipeline()
_pipeline: Optional[object] = None  # transformers Pipeline object, or None if not loaded
_use_local: bool = False  # True once local pipeline loaded successfully
# ── Local pipeline (optional – GPU Spaces) ────────────────────────────────────
def _try_load_local() -> bool:
    """
    Try to load HF_MODEL into a local transformers text-generation pipeline.

    Requires torch and transformers, which are commented out in
    requirements.txt by default – uncomment them to enable this path.
    Device selection: GPU 0 in float16 when CUDA is available, otherwise
    CPU in float32.  Note that the model IS loaded on a CPU-only machine if
    the packages are installed, which is slow and memory-hungry for the
    default 7B model.

    Missing packages are logged at INFO level (the expected case when the
    optional deps are not installed); any other failure is logged as a
    warning.  Nothing is raised, so callers can fall back to the API.

    On success, sets the module globals ``_pipeline`` and ``_use_local``.

    Returns:
        True if the pipeline was loaded, False otherwise.
    """
    global _pipeline, _use_local
    try:
        import torch
        from transformers import pipeline as hf_pipeline  # type: ignore
    except ImportError as exc:
        # Optional dependencies absent: not an error, just no local path.
        logger.info("[models] Local pipeline skipped: %s", exc)
        return False
    except Exception as exc:
        # Packages present but failed to import.
        logger.warning("[models] Local pipeline skipped: %s", exc)
        return False

    try:
        logger.info("[models] Loading %s locally …", HF_MODEL)
        device = 0 if torch.cuda.is_available() else -1
        _pipeline = hf_pipeline(
            "text-generation",
            model=HF_MODEL,
            token=HF_TOKEN or None,
            device=device,
            # Half precision only makes sense on GPU; CPU kernels need float32.
            torch_dtype=torch.float16 if device >= 0 else torch.float32,
            max_new_tokens=80,
        )
    except Exception as exc:
        # Download/auth errors, bad model id, allocation failures, …
        logger.warning("[models] Local pipeline skipped: %s", exc)
        return False

    _use_local = True
    logger.info("[models] Local pipeline ready (device=%s)", device)
    return True
# Set once the first load attempt has run, so a failed load is not retried
# (re-importing torch and re-fetching the model) on every call.
_load_attempted: bool = False


def get_pipeline():
    """
    Return the shared local pipeline, loading it on the first call.

    The load is attempted at most once per process: if it fails, later
    calls return None immediately instead of repeating the expensive
    import/download/load cycle.  ``_lock`` serialises the first attempt so
    concurrent callers cannot load the model twice.

    Returns:
        The transformers pipeline object, or None if it could not be loaded.
    """
    global _load_attempted
    with _lock:
        if _pipeline is None and not _load_attempted:
            _load_attempted = True
            _try_load_local()
        return _pipeline
# ── Output parser ─────────────────────────────────────────────────────────────
# Matches a quoted "action" field; used to salvage output whose JSON object
# was cut off before its closing brace (e.g. by the max_new_tokens limit).
_ACTION_FIELD_RE = re.compile(r'"action"\s*:\s*"([A-Za-z_]+)"')
_JSON_DECODER = json.JSONDecoder()


def _first_json_object(text: str) -> Optional[dict]:
    """Return the first complete JSON object embedded in *text*, or None."""
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the real end of the object, so braces inside
            # strings or nested objects are handled correctly.
            value, _end = _JSON_DECODER.raw_decode(text[start:])
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(value, dict):
                return value
        start = text.find("{", start + 1)
    return None


def _parse_action(raw: str) -> dict:
    """
    Extract {"action": ..., "thought": ...} from model output.

    Handles markdown fences, leading prose and trailing noise.  The first
    ``{`` that starts a complete, valid JSON object wins; earlier non-JSON
    brace groups are skipped.  If no complete object exists – typically
    because generation stopped mid-object – a quoted "action" value is
    salvaged on its own.

    Args:
        raw: Raw generated text (None is treated as empty).

    Returns:
        {"action": str, "thought": str}.  "action" is stripped, upper-cased
        and always a member of VALID_ACTIONS (unknown values become
        "WANDER"); a missing "thought" becomes "Processing…".

    Raises:
        ValueError: If neither a JSON object nor an "action" field is found.
    """
    # Strip markdown fences (``` and ```json, any case).
    cleaned = re.sub(r"```(?:json)?", "", raw or "", flags=re.IGNORECASE)

    obj = _first_json_object(cleaned)
    if obj is None:
        salvaged = _ACTION_FIELD_RE.search(cleaned)
        if salvaged is None:
            raise ValueError(f"No JSON object in model output: {raw!r}")
        obj = {"action": salvaged.group(1)}

    action = str(obj.get("action", "WANDER")).strip().upper()
    if action not in VALID_ACTIONS:
        logger.warning("[models] Unknown action '%s', defaulting to WANDER", action)
        action = "WANDER"
    return {
        "action": action,
        "thought": str(obj.get("thought", "Processing…")),
    }
# ── Inference paths ───────────────────────────────────────────────────────────
def infer_local(prompt: str) -> dict:
    """
    Generate an action with the in-process transformers pipeline.

    Samples up to 80 new tokens at temperature 0.7 and parses only the
    newly generated text (the prompt is not echoed back).  Obtaining the
    pipeline via get_pipeline() may trigger the model load.

    Raises:
        RuntimeError: If no local pipeline is available.
    """
    generator = get_pipeline()
    if generator is None:
        raise RuntimeError("Local pipeline not available")

    generation_kwargs = {
        "max_new_tokens": 80,
        "temperature": 0.7,
        "do_sample": True,
        "return_full_text": False,
    }
    first_result = generator(prompt, **generation_kwargs)[0]
    return _parse_action(first_result["generated_text"])
def infer_api(prompt: str) -> dict:
    """
    Call the HuggingFace Inference API (remote, no GPU needed).

    Sends *prompt* to the hosted HF_MODEL (80 new tokens, temperature 0.7,
    prompt not echoed, stop on blank line / "</s>" / "[INST]") and parses
    the completion with _parse_action.

    Raises:
        RuntimeError: If HF_TOKEN is unset, or the API returns an error
            payload or an unexpected response shape.
        requests.HTTPError: On a non-2xx status; the message includes an
            excerpt of the response body.
        requests.RequestException: On network failure or the 30 s timeout.
        ValueError: If the response is not JSON or the generated text holds
            no usable action.
    """
    if not HF_TOKEN:
        raise RuntimeError("HF_TOKEN not set – cannot call Inference API")

    url = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
    resp = requests.post(
        url,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        },
        json={
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 80,
                "temperature": 0.7,
                "return_full_text": False,
                "stop": ["\n\n", "</s>", "[INST]"],
            },
        },
        timeout=30,
    )
    try:
        resp.raise_for_status()
    except requests.HTTPError as exc:
        # The body typically explains the failure (model loading, rate limit,
        # bad token); keep a bounded excerpt so it reaches the logs.
        raise requests.HTTPError(
            f"{exc} – response: {resp.text[:300]}", response=resp
        ) from exc

    data = resp.json()
    # Normally a one-element list of {"generated_text": ...}; a bare object
    # is accepted too.  Anything else is reported as-is rather than being
    # parsed as empty model output.
    if isinstance(data, list):
        first = data[0] if data else None
    else:
        first = data
    if not isinstance(first, dict):
        raise RuntimeError(f"Unexpected Inference API response: {repr(data)[:300]}")
    if "error" in first:
        raise RuntimeError(f"Inference API error: {first['error']}")
    return _parse_action(first.get("generated_text", ""))
def run_inference(prompt: str) -> dict:
    """
    Main entry point used by server/app.py.

    Tries local → API → rule-based fallback in that order.
    Never raises – always returns a valid {"action", "thought"} dict.

    Returns:
        {"action": str, "thought": str}; {"action": "WANDER", ...} when no
        backend produced a result.
    """
    # 1. Local transformers pipeline (GPU Space).
    # NOTE(review): _use_local only becomes True once get_pipeline() has
    # loaded a model, and nothing in this function calls it – presumably the
    # server triggers the load at startup; confirm, or this path is dead.
    if _use_local:
        try:
            return infer_local(prompt)
        except Exception as exc:  # boundary: this function must never raise
            logger.warning("[models] Local inference failed: %s", exc)

    # 2. HuggingFace Inference API (CPU Space).
    if HF_TOKEN:
        try:
            return infer_api(prompt)
        except Exception as exc:  # boundary: this function must never raise
            logger.warning("[models] Inference API failed: %s", exc)

    # 3. Hard fallback.
    if _use_local or HF_TOKEN:
        logger.error("[models] All inference paths failed – returning WANDER")
    else:
        # Nothing was attempted: report missing configuration, not a failure.
        logger.warning("[models] No inference backend configured – returning WANDER")
    return {"action": "WANDER", "thought": "Inference offline. Wandering."}