bridge-troll / models.py
10Pratibh's picture
Hidden natures, reveal card, adapter support, mock banner
c828740
Raw
History Blame Contribute Delete
3.85 kB
"""Model backends for Bridge Troll.
Backends expose the same interface, `generate(messages) -> str`:
* MockTroll — keyword stub. No GPU/download. Plumbing/UI tests ONLY;
it does NOT understand natures or play the real game.
* TransformersTroll — Qwen2.5-7B-Instruct by default, optionally + a LoRA adapter.
Device: on a Space (env SPACE_ID set) force 'cuda' and load eagerly (ZeroGPU maps
it transparently); locally use the best available of cuda/mps/cpu.
Env switches:
BRIDGE_TROLL_MOCK=1 -> use the stub
BRIDGE_TROLL_MODEL=<repo> -> base model (default Qwen2.5-7B-Instruct)
BRIDGE_TROLL_ADAPTER=<repo or path> -> load this LoRA adapter on top (your fine-tune)
"""
from __future__ import annotations
import json
import os
import re
# Base checkpoint (Hugging Face repo id or local path) used by TransformersTroll.
MODEL_ID = os.getenv("BRIDGE_TROLL_MODEL", "Qwen/Qwen2.5-7B-Instruct")
# Optional LoRA adapter layered on MODEL_ID; unset (None) or empty means "no adapter".
ADAPTER = os.getenv("BRIDGE_TROLL_ADAPTER")  # e.g. "10Pratibh/gorm-lora"
# A non-empty SPACE_ID marks a Hugging Face Space; an empty value counts as local.
ON_SPACE = os.getenv("SPACE_ID", "") != ""
class MockTroll:
    """Keyword stub. NOT the game's intelligence — UI/plumbing tests only.

    `generate` classifies the last message by keyword and returns a JSON string
    with the keys "tactic", "persuasiveness", "reason" and "reply". Categories
    are checked in priority order: threat, manipulation, flattery, genuine
    appeal; anything else is treated as small talk.
    """

    _FLATTERY = ("great", "wonderful", "amazing", "best", "handsome", "wise", "kind troll")
    _THREAT = ("kill", "destroy", "smash", "burn", "or else", "make you", "force")
    _MANIP = ("the king sent", "i am your", "you must", "it is the law", "actually you")
    _GENUINE = ("please", "need", "family", "sick")

    @staticmethod
    def _mentions(text: str, keywords: tuple[str, ...]) -> bool:
        """Return True if any keyword in `keywords` starts at a word boundary in `text`.

        Anchoring only the start keeps inflected forms ("kills", "needs") matching,
        while no longer firing inside unrelated words ("skill" -> "kill",
        "otherwise" -> "wise", "reinforce" -> "force").
        """
        return any(re.search(r"\b" + re.escape(k), text) for k in keywords)

    def generate(self, messages: list[dict]) -> str:
        """Return a canned JSON verdict for the last message in `messages`.

        Only `messages[-1]["content"]` is inspected (case-insensitively).
        Raises IndexError if `messages` is empty.
        """
        last = messages[-1]["content"].lower()
        if self._mentions(last, self._THREAT):
            t, p, r, reply = "threat", 0, "tried to scare me", "Threats? Three hundred years of them. Cross elsewhere."
        elif self._mentions(last, self._MANIP):
            t, p, r, reply = "manipulation", 0, "false authority", "I smell a lie under that fine talk. No."
        elif self._mentions(last, self._FLATTERY):
            t, p, r, reply = "flattery", 0, "buttering me up", "Flattery slides off moss, traveller."
        elif self._mentions(last, self._GENUINE):
            # The only branch that awards persuasiveness.
            t, p, r, reply = "genuine", 3, "a real appeal", "Hm. You speak plainly, at least. Go on."
        else:
            t, p, r, reply = "smalltalk", 0, "no real argument", "Pleasant. Irrelevant. The bridge stays shut."
        return json.dumps({"tactic": t, "persuasiveness": p, "reason": r, "reply": reply})
def _local_device():
import torch
if torch.cuda.is_available():
return "cuda"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return "mps"
return "cpu"
class TransformersTroll:
    """Real backend: a `transformers` causal LM, optionally wrapped with a PEFT adapter.

    Args:
        model_id: Base checkpoint repo id or path (defaults to MODEL_ID).
        adapter: Adapter repo id or path loaded via `peft.PeftModel`; a falsy
            value skips it (defaults to ADAPTER).
    """

    def __init__(self, model_id: str = MODEL_ID, adapter: str | None = ADAPTER):
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # On a Space the device is always "cuda" (see module docstring re: ZeroGPU);
        # locally probe for the best available device.
        if ON_SPACE:
            self.device = "cuda"
        else:
            self.device = _local_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        network = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        if adapter:
            from peft import PeftModel

            network = PeftModel.from_pretrained(network, adapter)
        # Weights are moved to the target device only after the optional adapter is attached.
        self.model = network.to(self.device)

    def generate(self, messages: list[dict]) -> str:
        """Sample one assistant reply for the chat `messages` and return it as text.

        The prompt tokens are sliced off the output, so only the newly generated
        text (special tokens removed) is returned.
        """
        import torch

        prompt = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        prompt = prompt.to(self.model.device)
        prompt_len = prompt.shape[1]
        # One unpadded conversation, so every prompt position is attended.
        mask = torch.ones_like(prompt)
        sampling = dict(
            max_new_tokens=220,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        with torch.no_grad():
            generated = self.model.generate(prompt, attention_mask=mask, **sampling)
        reply_tokens = generated[0][prompt_len:]
        return self.tokenizer.decode(reply_tokens, skip_special_tokens=True)
def get_backend():
    """Return the backend selected by the environment.

    BRIDGE_TROLL_MOCK set to exactly "1" selects the keyword stub (MockTroll);
    any other value, or leaving it unset, loads the real model (TransformersTroll).
    """
    use_mock = os.environ.get("BRIDGE_TROLL_MOCK") == "1"
    return MockTroll() if use_mock else TransformersTroll()