File size: 6,915 Bytes
7fe39f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Model backends.

Three interchangeable backends behind one tiny interface:

    backend.chat(system: str, user: str) -> str

- `transformers`    : load the small model locally (default; GPU or CPU).
- `inference_api`   : call the Hugging Face serverless Inference API (no GPU).
- `mock`            : a deterministic fake that emits valid tagged output, so the
                      parser, engine and UI can be tested with no weights / network.

Pick with the MICRORPG_BACKEND env var. See README for all knobs.
"""

from __future__ import annotations

import os
import random
from typing import Protocol


# Model id used when a caller does not supply one; MICRORPG_MODEL overrides it.
DEFAULT_MODEL = os.getenv("MICRORPG_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
# Cap on tokens generated per reply; MICRORPG_MAX_TOKENS overrides it and must
# parse as an integer (int() raises ValueError at import time otherwise).
MAX_NEW_TOKENS = int(os.getenv("MICRORPG_MAX_TOKENS", "512"))


class Backend(Protocol):
    """Structural interface shared by the backends in this module.

    Any object with a ``name`` attribute and a matching ``chat`` method
    satisfies it; no inheritance is required.
    """

    # Short identifier of the backend implementation.
    name: str

    def chat(self, system: str, user: str) -> str:
        """Return the reply to one system prompt plus one user message."""


# --------------------------------------------------------------------------- #
# transformers (local)
# --------------------------------------------------------------------------- #
class TransformersBackend:
    """Run the chat model locally through Hugging Face ``transformers``.

    The base model is loaded from ``model_id``. When the MICRORPG_ADAPTER
    environment variable is set, the tokenizer is loaded from that directory
    and the PEFT adapter stored there is applied on top of the base model.
    """

    name = "transformers"

    def __init__(self, model_id: str = DEFAULT_MODEL):
        """Load tokenizer and model (plus the optional adapter).

        Args:
            model_id: Identifier of the base model passed to ``from_pretrained``.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself require torch / transformers.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.model_id = model_id
        adapter = os.environ.get("MICRORPG_ADAPTER")  # fine-tuned LoRA dir, optional

        # If an adapter is given, the tokenizer was saved alongside it (and may carry
        # the right chat template) — prefer it; otherwise load the base tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(adapter or model_id)

        # bfloat16 + automatic device placement with CUDA; float32 and default
        # placement without it.
        has_cuda = torch.cuda.is_available()
        dtype = torch.bfloat16 if has_cuda else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto" if has_cuda else None,
        )
        if adapter:
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, adapter)
            print(f"[llm] loaded fine-tuned adapter: {adapter}")
        # Kept so chat() can reach torch.no_grad() without a module-level import.
        self._torch = torch

    def chat(self, system: str, user: str) -> str:
        """Generate one sampled reply for a system prompt and a user message.

        Args:
            system: Content of the system message.
            user: Content of the user message.

        Returns:
            The newly generated text only (prompt tokens removed, special
            tokens skipped), stripped of surrounding whitespace.
        """
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        # return_dict=True gives input_ids together with the attention mask, so
        # generate() receives an explicit mask (pad_token_id is set to the EOS id
        # below) and nothing here depends on the call returning a bare tensor.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        prompt_len = encoded["input_ids"].shape[-1]

        with self._torch.no_grad():
            out = self.model.generate(
                **encoded,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation; keep only the continuation.
        text = self.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        return text.strip()


# --------------------------------------------------------------------------- #
# Hugging Face Inference API (serverless, no local GPU)
# --------------------------------------------------------------------------- #
class InferenceAPIBackend:
    """Send chat requests through ``huggingface_hub.InferenceClient``.

    The access token is read from HF_TOKEN, falling back to
    HUGGING_FACE_HUB_TOKEN; when neither is set, ``None`` is passed as the token.
    """

    name = "inference_api"

    def __init__(self, model_id: str = DEFAULT_MODEL):
        """Create the client for ``model_id``.

        Args:
            model_id: Model identifier handed to ``InferenceClient``.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself require huggingface_hub.
        from huggingface_hub import InferenceClient

        token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        self.model_id = model_id
        self.client = InferenceClient(model=model_id, token=token)

    def chat(self, system: str, user: str) -> str:
        """Request one chat completion and return its text.

        Args:
            system: Content of the system message.
            user: Content of the user message.

        Returns:
            The first choice's message content with surrounding whitespace
            stripped, or an empty string when that content is missing.
        """
        resp = self.client.chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            max_tokens=MAX_NEW_TOKENS,
            temperature=0.8,
            top_p=0.9,
        )
        # `content` can be None in a chat-completion message; calling .strip()
        # on it would raise AttributeError instead of returning a str.
        content = resp.choices[0].message.content
        return (content or "").strip()


# --------------------------------------------------------------------------- #
# mock (no weights, no network) — emits valid tagged output
# --------------------------------------------------------------------------- #
class MockBackend:
    """Offline stand-in for a real model.

    Every reply is a well-formed ``<narrative>`` / ``<state>`` / ``<choices>``
    turn. An attack while "in combat" gets a fixed exchange of blows; any other
    message gets a scene drawn from a small table by a fixed-seed RNG. Needs no
    weights and no network.
    """

    name = "mock"

    # (narrative template, state line) pairs used outside combat.
    _SCENES = [
        ("A cold wind drags mist across {loc}. Something shifts in the dark ahead.",
         "ENEMY: Mist Wraith|hp=10|atk=3"),
        ("You find a leather pouch half-buried in the mud. Coins glint inside.",
         "GOLD: +7"),
        ("An old hermit beckons you toward a flickering lantern.",
         "NPC: Aldric|hermit|friendly|knows the old roads"),
        ("A rusted chest yields a glimmer of steel.",
         "ITEM_ADD: Iron Shortsword"),
        ("The path opens onto a ruined chapel, its bell long silent.",
         "LOCATION: The Ruined Chapel"),
    ]

    def __init__(self, model_id: str = "mock"):
        # Fixed seed: a fresh instance always yields the same scene sequence.
        self._rng = random.Random(7)
        self.model_id = model_id

    def chat(self, system: str, user: str) -> str:
        """Build one tagged turn from the user message; ``system`` is not read."""
        lowered = user.lower()

        # Every "location:" line is collected; the last one wins.
        places = [
            row.partition(":")[2].strip()
            for row in user.splitlines()
            if row.lower().startswith("location:")
        ]
        where = places[-1] if places else "the crossroads"

        # Combat-aware: an attack verb only counts while the message says "in combat".
        attacking = False
        if "in combat" in lowered:
            for verb in ("attack", "strike", "hit", "swing", "stab"):
                if verb in lowered:
                    attacking = True
                    break

        if not attacking:
            template, delta = self._rng.choice(self._SCENES)
            story = template.format(loc=where)
            options = ["1. Investigate closely.", "2. Move on carefully.", "3. Call out."]
        else:
            # Hurt the enemy and take a hit back.
            story = "You lunge forward and your blade bites home; the creature shrieks and claws back."
            delta = "ENEMY_HP: -6\nHP: -3\nXP: +4"
            options = ["1. Press the attack.", "2. Back away and guard.", "3. Try to flee."]

        pieces = [
            "<narrative>", story, "</narrative>",
            "<state>", delta, "</state>",
            "<choices>", *options, "</choices>",
        ]
        return "\n".join(pieces)


# --------------------------------------------------------------------------- #
# factory
# --------------------------------------------------------------------------- #
def build_backend(kind: str | None = None, model_id: str | None = None) -> Backend:
    """Create the backend selected by ``kind``.

    Args:
        kind: ``"transformers"``/``"local"``, ``"inference_api"``/``"api"``/
            ``"inference"`` or ``"mock"``. Matching is case-insensitive and
            ignores surrounding whitespace. When omitted or blank, the
            MICRORPG_BACKEND environment variable is used; when that is unset
            or blank too, ``"transformers"``.
        model_id: Model identifier for the transformers and inference-API
            backends; falls back to ``DEFAULT_MODEL``. The mock backend is
            built without it.

    Returns:
        A newly constructed backend instance.

    Raises:
        ValueError: If the resolved name matches none of the known backends.
    """
    # An exported-but-empty variable (`MICRORPG_BACKEND=`) yields "", which must
    # fall through to the default instead of failing as an unknown backend.
    selected = (
        (kind or "").strip()
        or os.environ.get("MICRORPG_BACKEND", "").strip()
        or "transformers"
    )
    kind = selected.lower()
    model_id = model_id or DEFAULT_MODEL

    if kind == "mock":
        return MockBackend()
    if kind in ("inference_api", "api", "inference"):
        return InferenceAPIBackend(model_id)
    if kind in ("transformers", "local"):
        return TransformersBackend(model_id)
    raise ValueError(f"Unknown backend: {kind!r}")