File size: 6,915 Bytes
"""Model backends.

Three interchangeable backends behind one tiny interface:

    backend.chat(system: str, user: str) -> str

- `transformers`  : load the small model locally (default; GPU or CPU).
- `inference_api` : call the Hugging Face serverless Inference API (no GPU).
- `mock`          : a deterministic fake that emits valid tagged output, so the
  parser, engine and UI can be tested with no weights or network access.

Select a backend with the MICRORPG_BACKEND env var. See the README for all knobs.
"""
from __future__ import annotations
import os
import random
from typing import Protocol
# Hugging Face model id used by the real backends unless a caller supplies one.
DEFAULT_MODEL: str = os.getenv("MICRORPG_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
# Cap on newly generated tokens per reply, shared by the local and API backends.
MAX_NEW_TOKENS: int = int(os.getenv("MICRORPG_MAX_TOKENS", "512"))
class Backend(Protocol):
    """Structural interface shared by every model backend.

    Any object exposing a ``name`` string and a ``chat(system, user)`` method
    that returns the reply text satisfies it; no inheritance is required.
    """

    # Short identifier of the backend implementation.
    name: str

    def chat(self, system: str, user: str) -> str:
        """Return the model's reply to ``user`` given the ``system`` prompt."""
        ...
# --------------------------------------------------------------------------- #
# transformers (local)
# --------------------------------------------------------------------------- #
class TransformersBackend:
    """Run the chat model locally with Hugging Face ``transformers``.

    Loads weights in bfloat16 with ``device_map="auto"`` when CUDA is available,
    and in float32 on the default (CPU) device otherwise. If the
    ``MICRORPG_ADAPTER`` env var names a fine-tuned LoRA directory, that adapter
    is applied on top of the base model with ``peft``.
    """

    name = "transformers"

    def __init__(self, model_id: str = DEFAULT_MODEL):
        """Load the tokenizer and model weights.

        Args:
            model_id: Hugging Face model id (or local path) of the base model.
        """
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.model_id = model_id
        adapter = os.environ.get("MICRORPG_ADAPTER")  # fine-tuned LoRA dir, optional
        # If an adapter is given, the tokenizer was saved alongside it (and may
        # carry the right chat template), so prefer it. Otherwise load the base
        # tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(adapter or model_id)

        use_cuda = torch.cuda.is_available()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
        if adapter:
            from peft import PeftModel

            self.model = PeftModel.from_pretrained(self.model, adapter)
            print(f"[llm] loaded fine-tuned adapter: {adapter}")
        self._torch = torch

    def chat(self, system: str, user: str) -> str:
        """Generate one sampled reply to ``user`` under the ``system`` prompt.

        Sampling uses temperature 0.8, top-p 0.9 and repetition penalty 1.1,
        capped at ``MAX_NEW_TOKENS`` new tokens.

        Returns:
            Only the newly generated text (the prompt is sliced off), with
            special tokens removed and surrounding whitespace stripped.
        """
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        # return_dict=True yields both input_ids and attention_mask. The mask
        # must be passed explicitly: pad_token_id is set to eos below, and
        # generate() cannot infer a mask when pad and eos are the same token.
        # Asking for the dict explicitly also makes the return type (a
        # BatchEncoding) independent of the library's default.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)
        input_ids = encoded["input_ids"]
        prompt_len = input_ids.shape[-1]
        with self._torch.no_grad():
            out = self.model.generate(
                input_ids=input_ids,
                attention_mask=encoded["attention_mask"],
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the continuation, not the echoed prompt tokens.
        text = self.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        return text.strip()
# --------------------------------------------------------------------------- #
# Hugging Face Inference API (serverless, no local GPU)
# --------------------------------------------------------------------------- #
class InferenceAPIBackend:
    """Query a hosted model through the Hugging Face Inference API.

    No local weights or GPU are needed. The access token is read from
    ``HF_TOKEN``, falling back to ``HUGGING_FACE_HUB_TOKEN``. When neither is
    set, ``token=None`` is handed to ``InferenceClient`` unchanged.
    """

    name = "inference_api"

    def __init__(self, model_id: str = DEFAULT_MODEL):
        """Create an ``InferenceClient`` bound to ``model_id``.

        Args:
            model_id: Hugging Face model id to query.
        """
        from huggingface_hub import InferenceClient

        token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        self.model_id = model_id
        self.client = InferenceClient(model=model_id, token=token)

    def chat(self, system: str, user: str) -> str:
        """Send one system + user exchange and return the reply text.

        Sampling uses temperature 0.8 and top-p 0.9, capped at
        ``MAX_NEW_TOKENS`` tokens.

        Returns:
            The first choice's message content with surrounding whitespace
            stripped, or ``""`` when the response carries no content.
        """
        resp = self.client.chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            max_tokens=MAX_NEW_TOKENS,
            temperature=0.8,
            top_p=0.9,
        )
        # `content` is optional in the chat-completion output and may be None.
        # Avoid an AttributeError on .strip() and hand the parser an empty
        # reply instead.
        content = resp.choices[0].message.content
        return (content or "").strip()
# --------------------------------------------------------------------------- #
# mock (no weights, no network) — emits valid tagged output
# --------------------------------------------------------------------------- #
class MockBackend:
    """Stand-in model that needs neither weights nor network access.

    Every call returns a well-formed ``<narrative>`` / ``<state>`` /
    ``<choices>`` turn derived from the user message, so the parser, engine
    and UI can be driven end to end. The output is plausible but not smart.
    Scene selection comes from an RNG seeded per instance, so a fresh instance
    always replays the same sequence.
    """

    name = "mock"

    # (narrative template, state line) pairs. "{loc}" in a template is replaced
    # with the location parsed from the user message.
    _SCENES = [
        ("A cold wind drags mist across {loc}. Something shifts in the dark ahead.",
         "ENEMY: Mist Wraith|hp=10|atk=3"),
        ("You find a leather pouch half-buried in the mud. Coins glint inside.",
         "GOLD: +7"),
        ("An old hermit beckons you toward a flickering lantern.",
         "NPC: Aldric|hermit|friendly|knows the old roads"),
        ("A rusted chest yields a glimmer of steel.",
         "ITEM_ADD: Iron Shortsword"),
        ("The path opens onto a ruined chapel, its bell long silent.",
         "LOCATION: The Ruined Chapel"),
    ]

    def __init__(self, model_id: str = "mock"):
        self.model_id = model_id
        # Fixed seed: identical call sequences produce identical output.
        self._rng = random.Random(7)

    def chat(self, system: str, user: str) -> str:
        """Return a tagged turn for ``user``. ``system`` is not consulted.

        If the lower-cased message contains "in combat" plus an attack verb,
        a fixed combat exchange is returned. Otherwise a scene is drawn from
        ``_SCENES``. The location is taken from the last line that starts with
        "location:" (case-insensitive) and defaults to "the crossroads".
        """
        lowered = user.lower()

        location = "the crossroads"
        for row in user.splitlines():
            if row.lower().startswith("location:"):
                location = row.partition(":")[2].strip()

        attack_verbs = ("attack", "strike", "hit", "swing", "stab")
        attacking = any(verb in lowered for verb in attack_verbs)

        if "in combat" in lowered and attacking:
            # Player strikes: damage the enemy and take a hit in return.
            narrative = (
                "You lunge forward and your blade bites home; "
                "the creature shrieks and claws back."
            )
            state = "\n".join(("ENEMY_HP: -6", "HP: -3", "XP: +4"))
            options = ("1. Press the attack.", "2. Back away and guard.", "3. Try to flee.")
        else:
            template, state = self._rng.choice(self._SCENES)
            narrative = template.format(loc=location)
            options = ("1. Investigate closely.", "2. Move on carefully.", "3. Call out.")

        sections = (
            ("narrative", narrative),
            ("state", state),
            ("choices", "\n".join(options)),
        )
        return "\n".join(f"<{tag}>\n{body}\n</{tag}>" for tag, body in sections)
# --------------------------------------------------------------------------- #
# factory
# --------------------------------------------------------------------------- #
def build_backend(kind: str | None = None, model_id: str | None = None) -> Backend:
    """Construct the backend named by ``kind`` or by ``MICRORPG_BACKEND``.

    Args:
        kind: Backend name. Matching is case-insensitive and ignores surrounding
            whitespace. Accepted values are ``"mock"``, ``"inference_api"``
            (aliases ``"api"`` and ``"inference"``) and ``"transformers"``
            (alias ``"local"``). A missing or blank value falls back to the
            ``MICRORPG_BACKEND`` env var. If that is unset or blank too,
            ``"transformers"`` is used.
        model_id: Model id for the real backends; defaults to ``DEFAULT_MODEL``.
            The mock backend ignores it.

    Returns:
        A freshly constructed backend instance.

    Raises:
        ValueError: If the resolved name matches no backend.
    """
    # Blank or padded values (e.g. MICRORPG_BACKEND="" or " mock\n" from a
    # .env file) fall through to the next source instead of being rejected
    # as an unknown backend.
    kind = (
        (kind or "").strip()
        or os.environ.get("MICRORPG_BACKEND", "").strip()
        or "transformers"
    ).lower()
    model_id = model_id or DEFAULT_MODEL
    if kind == "mock":
        return MockBackend()
    if kind in ("inference_api", "api", "inference"):
        return InferenceAPIBackend(model_id)
    if kind in ("transformers", "local"):
        return TransformersBackend(model_id)
    raise ValueError(f"Unknown backend: {kind!r}")
|