Longshot / app.py
kobinasam's picture
initial-commit
cb88544 verified
Raw
History Blame Contribute Delete
16.2 kB
"""
Longshot
========
A writing game played inside a language model's predictions. You build a sentence one
word at a time. At each step you see the model's real top candidate next-words with
their actual probabilities, and you pick one. Your score is the total SURPRISE you
create, measured in bits (-log2 of the probability you chose). Pick what the model
expects and you score nothing; pick its longshots and the sentence turns strange.
Track 2 toy for the "Small Models Big Adventures" hackathon.
Model: openbmb/MiniCPM5-1B (1B, on-device, well under the 32B ceiling).
Why a SMALL model is the point: the game IS the next-token distribution, so it needs
the model's real logits, and a 1B model is fast enough to predict after every single
word AND quirky enough to make fighting its expectations genuinely fun. One forward
pass per word, no GPU required.
Modes: model (real logits) / keeper (a mock distribution, so it always demos).
"""
import os
import math
import html
import inspect
import random
import gradio as gr
# Feature-detect optional keyword arguments so the same file runs on Gradio
# versions whose Blocks(), Blocks.launch() and Textbox() signatures differ.
_BLOCKS_HAS_CSS, _LAUNCH_HAS_SSR, _TB_HAS_COPY = (
    keyword in inspect.signature(target).parameters
    for target, keyword in (
        (gr.Blocks.__init__, "css"),
        (gr.Blocks.launch, "ssr_mode"),
        (gr.Textbox.__init__, "show_copy_button"),
    )
)
# Extra Textbox kwargs: request the copy button only where it is supported.
_COPY = dict(show_copy_button=True) if _TB_HAS_COPY else {}
def _env_int(name, default, minimum=1):
    """Read an integer setting from environment variable ``name``.

    Returns ``default`` when the variable is unset, blank, or not an integer.
    A bad value is reported, not raised, so a configuration typo cannot
    crash the app at import time. Valid values are clamped to at least
    ``minimum``.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        print(f"[Longshot] Ignoring non-integer {name}={raw!r}; using {default}.")
        return default
    return max(minimum, value)


MODEL_ID = os.environ.get("LONGSHOT_MODEL", "openbmb/MiniCPM5-1B")
DEBUG = os.environ.get("LONGSHOT_DEBUG", "").strip().lower() in {"1", "true", "yes"}
# Number of words the player adds before the game ends automatically.
MAX_STEPS = _env_int("LONGSHOT_MAX_STEPS", 24)
K = 6  # candidate words shown per step
SEEDS = ["The old lighthouse keeper", "On the third day of the voyage",
         "Nobody expected the cheese to", "She opened the box and found"]
# Characters that let a candidate count as a standalone "word" even without
# a leading space (punctuation attaches directly to the previous word).
_PUNCT = set(".,;:!?\u2014\u2026\"')(-")
def _bits(p):
    """Return the surprisal of probability ``p`` in bits, i.e. -log2(p).

    ``p`` is floored at 1e-9, so a zero probability yields a large finite
    score (about 29.9 bits) instead of a math domain error.
    """
    floored = max(p, 1e-9)
    return -math.log2(floored)
def label(bits):
    """Map a cumulative surprise score (in bits) to its display tier name.

    Tiers are half-open bands: below 18, 42, 80 and 140 bits respectively;
    anything at or above 140 is "Beyond comprehension".
    """
    for ceiling, tier in ((18, "Predictable"), (42, "Curious"),
                          (80, "Eccentric"), (140, "Unhinged")):
        if bits < ceiling:
            return tier
    return "Beyond comprehension"
# ----------------------------------------------------------------- model -------
def _noop_gpu(*a, **k):
    """Pass-through stand-in for ``spaces.GPU`` outside ZeroGPU Spaces.

    Supports both decorator forms: ``@_noop_gpu`` returns the function
    unchanged, and ``@_noop_gpu(duration=...)`` returns an identity decorator.
    All keyword arguments are ignored.
    """
    def identity(fn):
        return fn

    if a and callable(a[0]):
        # Bare ``@_noop_gpu`` usage: the decorated function arrives directly.
        return a[0]
    return identity
# Choose the GPU decorator. On a ZeroGPU Space (SPACES_ZERO_GPU is "true" or
# "1") use spaces.GPU. Anywhere else, or if `spaces` cannot be imported, fall
# back to the pass-through _noop_gpu.
GPU = _noop_gpu
if os.environ.get("SPACES_ZERO_GPU", "").lower() in {"1", "true"}:
    try:
        import spaces
        GPU = spaces.GPU
    except Exception:  # noqa: BLE001
        GPU = _noop_gpu

# Module state filled in by load_model(); MODE stays "keeper" unless a model
# loads successfully.
MODE = "keeper"
_tokenizer = None
_model = None
def load_model():
    """Try to load MODEL_ID, switching MODE to "model" on success.

    Sets the module globals ``_tokenizer``, ``_model`` and ``MODE``. Any
    exception raised while importing torch/transformers or loading the
    tokenizer or model is caught and printed, and the app stays in "keeper"
    (mock) mode. The globals are assigned only after everything has loaded.
    On failure both are reset to None, so a tokenizer is never kept alive
    without its model.
    """
    global _tokenizer, _model, MODE
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the model repo; only point LONGSHOT_MODEL at repositories you trust.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True)
        model.eval()
    except Exception as exc:  # noqa: BLE001 -- any failure means keeper mode
        _tokenizer = None
        _model = None
        MODE = "keeper"
        print(f"[Longshot] Could not load {MODEL_ID} ({exc}). Keeper mode active.")
        return
    _tokenizer, _model = tokenizer, model
    MODE = "model"
    print(f"[Longshot] Loaded {MODEL_ID} -- model mode (real logits).")


load_model()
def _select_candidates(decoded, text, k):
    """Reduce decoded ``(piece, p)`` pairs to at most ``k`` word candidates.

    ``decoded`` is consumed lazily, in the order given, and consumption stops
    as soon as ``k`` candidates are found. A piece qualifies when it starts
    with a space, starts with punctuation (``_PUNCT``), or when ``text`` is
    empty. Pieces that are empty, contain U+FFFD (a partial multi-byte
    character), or are pure whitespace are skipped. Each stripped word is kept
    only once, at its first (highest-probability) occurrence, so tokens such
    as "," and " ," do not yield two identical buttons. If fewer than three
    candidates qualify, the first ``k`` distinct non-empty pieces are used
    instead, with a leading space forced onto each.

    Returns a list of dicts ``{"piece": str, "word": str, "p": float}``.
    """
    cands, fallback = [], []
    cand_words, fallback_words = set(), set()
    for piece, p in decoded:
        if not piece or "\ufffd" in piece:
            continue
        word = piece.strip()
        if not word:
            continue
        if word not in fallback_words:
            fallback_words.add(word)
            fallback.append((piece, word, p))
        if word in cand_words:
            continue
        if piece[:1] == " " or word[0] in _PUNCT or not text:
            cand_words.add(word)
            cands.append({"piece": piece, "word": word, "p": p})
            if len(cands) >= k:
                break
    if len(cands) < 3:  # tokenizer gave mostly sub-words
        cands = [{"piece": pc if pc[:1] == " " else " " + w, "word": w, "p": p}
                 for pc, w, p in fallback[:k]]
    return cands


@GPU(duration=30)
def _predict_model(text, k=K):
    """Run one forward pass on ``text`` and return up to ``k`` next-word candidates.

    ``p`` for each candidate is the softmax probability of that single next
    token. Only the 60 most likely tokens are considered. If CUDA is
    available and the model is not yet on it, the model is moved there once,
    in bfloat16. ``GPU`` is ``spaces.GPU`` on ZeroGPU Spaces and a no-op
    decorator otherwise.
    """
    import torch
    if torch.cuda.is_available() and next(_model.parameters()).device.type != "cuda":
        _model.to("cuda", dtype=torch.bfloat16)
    # An empty prompt falls back to the BOS token (or a single space).
    prompt = text if text else (_tokenizer.bos_token or " ")
    enc = _tokenizer(prompt, return_tensors="pt").to(_model.device)
    with torch.no_grad():
        logits = _model(**enc).logits[0, -1].float()
    probs = torch.softmax(logits, dim=-1)
    topv, topi = probs.topk(60)
    # Generator: decoding stops as soon as enough candidates are found.
    decoded = ((_tokenizer.decode([int(tid)]), p)
               for p, tid in zip(topv.tolist(), topi.tolist()))
    return _select_candidates(decoded, text, k)
# Word pool sampled by keeper (mock) mode.
_BANK = ("the a and then slowly suddenly quietly because "
         "moon door whispered forgot umbrella sea clockwork "
         "nobody almost perhaps velvet thunder sang river "
         "glass remembered blue forever maybe stone feather drifted").split()
# Mock probabilities in descending order, assigned to keeper candidates by position.
_KP = [0.30, 0.20, 0.14, 0.10, 0.07, 0.05]
def _keeper_candidates(text, k=K):
    """Return a deterministic mock candidate list for ``text``.

    The RNG is seeded with the last 40 characters of ``text`` (or "seed"
    when empty), so the same text always yields the same words. It draws
    ``min(k, len(_BANK))`` distinct words from ``_BANK`` and assigns
    probabilities from ``_KP`` by position, cycling if ``k`` exceeds
    ``len(_KP)``.
    """
    seed_key = text[-40:] or "seed"
    picks = random.Random(seed_key).sample(_BANK, min(k, len(_BANK)))
    n_probs = len(_KP)
    out = []
    for pos, word in enumerate(picks):
        out.append({"piece": " " + word, "word": word, "p": _KP[pos % n_probs]})
    return out
def next_candidates(text):
    """Return next-word candidates for ``text``.

    Uses the real model when MODE is "model". Falls back to the keeper mock
    when not in model mode, when prediction raises (the error is printed),
    or when the model returns no candidates.
    """
    if MODE != "model":
        return _keeper_candidates(text)
    try:
        cands = _predict_model(text)
    except Exception as exc:  # noqa: BLE001
        print(f"[Longshot] predict error: {exc}")
    else:
        if cands:
            return cands
    return _keeper_candidates(text)
# --------------------------------------------------------------- rendering -----
def esc(s):
    """Return ``s`` converted to ``str`` and HTML-escaped (quotes included)."""
    as_text = str(s)
    return html.escape(as_text, quote=True)
def base_state():
    """Return a fresh, unstarted game-state dict.

    A new dict (with new lists) is built on every call, so games never share
    mutable history or candidates.
    """
    return dict(text="", score=0.0, steps=0, history=[],
                started=False, finished=False, candidates=[])
def share_text(state):
    """Return the one-line shareable summary: score, tier and the sentence."""
    total = state["score"]
    sentence = state["text"].strip()
    return f'{total:.0f} bits of surprise ({label(total)}) - "{sentence}" · played in Longshot'
def _candidate_row(rank, cand):
    """Render one candidate as an HTML row: escaped word, probability bar, percent.

    The bar width is sqrt(p) scaled to percent and clamped to [3, 100], so
    unlikely words still get a visible sliver. The row at rank 0 (the first
    candidate in the list) is tagged "expected".
    """
    prob = cand["p"]
    width = max(3, min(100, round((prob ** 0.5) * 100)))
    tag = ' <span class="exp">expected</span>' if rank == 0 else ''
    return (f'<div class="cand"><span class="cw">{esc(cand["word"])}</span>'
            f'<span class="bar"><span style="width:{width}%"></span></span>'
            f'<span class="cp">{prob * 100:.1f}%</span>{tag}</div>')


def render_stage(state):
    """Return the HTML for the main stage panel.

    Renders one of three views:
    - not started: a dashed placeholder prompt;
    - finished: the sentence, the score badge and a closing summary;
    - in play: the sentence with a blinking cursor, the score badge,
      instructions and one bar row per candidate.
    All user or model text is escaped.
    """
    if not state.get("started"):
        return ('<div class="stage empty">Pick an opening below to begin, then build a '
                'sentence the model never saw coming.</div>')
    done = state.get("finished")
    caret = "" if done else '<span class="cursor">\u258c</span>'
    story = f'<div class="story">{esc(state.get("text", ""))}{caret}</div>'
    total = state["score"]
    tier = label(total)
    # The tier's first word selects the CSS class, e.g. "Beyond comprehension" -> lab-beyond.
    badge = (f'<div class="score">Surprise <b>{total:.1f} bits</b>'
             f'<span class="lab lab-{tier.split()[0].lower()}">{tier}</span></div>')
    if done:
        summary = (f'<div class="final">The oracle is <b>{tier.lower()}</b>. '
                   f'{state["steps"]} words, {total:.1f} bits of surprise.</div>')
        return f'<div class="stage">{story}{badge}{summary}</div>'
    rows = "".join(_candidate_row(rank, cand)
                   for rank, cand in enumerate(state.get("candidates") or []))
    ask = ('<div class="ask">Choose the next word below. The longer the bar, the more '
           'the model expects it. The bolder your longshot, the more surprise you score.</div>')
    return f'<div class="stage">{story}{badge}{ask}<div class="cands">{rows}</div></div>'
def view(state):
    """Build the output list for one UI refresh.

    Order: [stage HTML, K pick-button updates, share-box update, state].
    A pick button is shown only while the game is started and unfinished and
    a candidate exists for its slot; its label is the word truncated to 24
    characters (a middle dot if empty). The share box is visible only once
    the game is finished.
    """
    done = state.get("finished")
    live = state.get("candidates") or []
    playable = not done and state.get("started")
    buttons = [
        gr.update(value=(live[slot]["word"][:24] or "\u00b7"), visible=True)
        if playable and slot < len(live)
        else gr.update(visible=False)
        for slot in range(K)
    ]
    if done:
        share = gr.update(value=share_text(state), visible=True)
    else:
        share = gr.update(value="", visible=False)
    return [render_stage(state), *buttons, share, state]
# --------------------------------------------------------------- the game ------
def _commit(state, c):
    """Append candidate ``c`` to the sentence and advance the game, in place.

    Adds ``c["piece"]`` to the text, adds the candidate's surprisal to the
    score, bumps the step counter and records ``(word, p)`` in the history.
    Fetches new candidates for the extended text, or, once MAX_STEPS words
    have been played, marks the game finished with no candidates. Returns the
    same (mutated) state dict.
    """
    state["text"] += c["piece"]
    state["score"] += _bits(c["p"])
    state["steps"] += 1
    state["history"].append((c["word"], c["p"]))
    if state["steps"] < MAX_STEPS:
        state["candidates"] = next_candidates(state["text"])
        return state
    state["finished"] = True
    state["candidates"] = []
    return state
def on_begin(seed):
    """Start a new game from the opening text ``seed``.

    A blank or missing seed renders a fresh, unstarted board. Otherwise the
    seed is stripped and truncated to 200 characters. Any whitespace left at
    the cut is removed, because model candidates carry their own leading
    space and would otherwise produce a double space. The first candidates
    are then fetched.
    """
    seed = (seed or "").strip()
    if not seed:
        return view(base_state())
    # Derive from base_state() so the state schema is defined in one place.
    state = base_state()
    state["text"] = seed[:200].rstrip()
    state["started"] = True
    state["candidates"] = next_candidates(state["text"])
    return view(state)
def on_pick(i, state):
    """Commit the candidate in slot ``i`` (a pick-button click).

    Has no effect on the state (just re-renders) when no game is in
    progress, or when ``i`` is not an int index into the current candidates.
    """
    state = state or base_state()
    in_progress = state.get("started") and not state.get("finished")
    cands = state.get("candidates") or []
    if in_progress and isinstance(i, int) and 0 <= i < len(cands):
        return view(_commit(state, cands[i]))
    return view(state)
def _auto_pick(state, chooser):
    """Commit the candidate chosen by ``chooser`` (``min`` or ``max`` by probability).

    Shared by the "longshot" and "oracle" buttons. Just re-renders when no
    game is in progress or there are no candidates. On ties, the first
    matching candidate in list order wins.
    """
    state = state or base_state()
    cands = state.get("candidates") or []
    if not state.get("started") or state.get("finished") or not cands:
        return view(state)
    return view(_commit(state, chooser(cands, key=lambda cand: cand["p"])))


def on_longshot(state):
    """Auto-play the least likely visible candidate (maximum surprise)."""
    return _auto_pick(state, min)


def on_oracle(state):
    """Auto-play the most likely visible candidate (minimum surprise)."""
    return _auto_pick(state, max)
def on_finish(state):
    """End the current game early; an unstarted board is simply re-rendered."""
    state = state or base_state()
    if not state.get("started"):
        return view(state)
    state.update(finished=True, candidates=[])
    return view(state)
def on_new():
    """Discard the current game and render a fresh, unstarted board."""
    fresh = base_state()
    return view(fresh)
# App stylesheet. It defines:
# - the parchment palette as custom properties on :root;
# - overrides of Gradio theme variables on .gradio-container (including the
#   .dark selectors, so the same light palette is used in dark mode);
# - the classes emitted by the header and stage HTML (.ls-*, .stage, .story,
#   .cand, .lab-*, ...).
# The @import fetches the Fraunces and Spectral fonts from Google Fonts, and
# `footer{display:none}` hides the page footer. The stylesheet is passed to
# gr.Blocks(css=...) or launch(css=...), whichever accepts a css argument.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Fraunces:ital,opsz,wght@0,9..144,400;0,9..144,600;1,9..144,400&family=Spectral:ital,wght@0,400;0,500;1,400&display=swap');
:root{--paper:#f4eee0;--paper-2:#ece3cf;--ink:#2b2a2e;--ink-soft:#6a6470;
--violet:#5b4b8a;--rose:#a8456b;--line:#cdbfa6;--gold:#b4892f;}
.gradio-container,.gradio-container.dark,.dark{
--body-background-fill:transparent;--background-fill-primary:#fffcf4;--background-fill-secondary:#f3ecda;
--block-background-fill:#fffcf4;--block-border-color:var(--line);--border-color-primary:var(--line);
--body-text-color:var(--ink);--body-text-color-subdued:var(--ink-soft);
--block-label-text-color:var(--violet);--block-title-text-color:var(--ink);
--block-label-background-fill:#ece3cf;--block-title-background-fill:transparent;
--input-background-fill:#fffcf4;--input-border-color:var(--line);--input-placeholder-color:var(--ink-soft);
--button-primary-background-fill:var(--violet);--button-primary-background-fill-hover:#493b70;
--button-primary-text-color:#fbf7ec;--button-primary-border-color:#493b70;
--button-secondary-background-fill:#ece3cf;--button-secondary-background-fill-hover:#e2d4b6;
--button-secondary-text-color:var(--ink);--button-secondary-border-color:var(--line);
--color-accent:var(--rose);--color-accent-soft:#f3dbe4;}
.gradio-container{background:radial-gradient(120% 80% at 80% -10%,#fbf5e6,var(--paper) 55%,var(--paper-2));
font-family:'Spectral',Georgia,serif !important;color:var(--ink) !important;max-width:880px !important;}
.gradio-container textarea,.gradio-container input[type="text"],.gradio-container input:not([type]){
background:#fffcf4 !important;color:var(--ink) !important;-webkit-text-fill-color:var(--ink) !important;border-color:var(--line) !important;}
.gradio-container textarea::placeholder,.gradio-container input::placeholder{color:var(--ink-soft) !important;-webkit-text-fill-color:var(--ink-soft) !important;opacity:1;}
.ls-title{font-family:'Fraunces',serif;font-weight:600;font-size:2.6rem;line-height:1;margin:.2rem 0 0;}
.ls-title em{font-style:italic;color:var(--violet);}
.ls-sub{font-style:italic;color:var(--ink-soft);margin:.35rem 0 1rem;font-size:1.05rem;}
.ls-mode{display:inline-block;font-size:.72rem;letter-spacing:.12em;text-transform:uppercase;color:var(--violet);border:1px solid var(--line);border-radius:999px;padding:.15rem .6rem;}
.stage{background:#fffcf4;border:1px solid var(--line);border-radius:16px;padding:20px 22px;box-shadow:0 16px 40px -26px rgba(43,42,46,.7);}
.stage.empty{color:var(--ink-soft);font-style:italic;text-align:center;border-style:dashed;}
.story{font-family:'Fraunces',serif;font-size:1.5rem;line-height:1.5;color:var(--ink);}
.cursor{color:var(--rose);animation:blink 1s steps(2) infinite;font-weight:600;}
@keyframes blink{50%{opacity:0;}}
.score{margin:14px 0 4px;color:var(--ink-soft);}
.score b{color:var(--ink);font-family:'Fraunces',serif;}
.lab{margin-left:10px;font-size:.75rem;letter-spacing:.1em;text-transform:uppercase;border:1px solid var(--line);border-radius:999px;padding:.12rem .55rem;color:var(--violet);}
.lab-eccentric,.lab-unhinged,.lab-beyond{color:var(--rose);border-color:var(--rose);}
.ask{color:var(--ink-soft);font-size:.92rem;margin:10px 0 12px;}
.cands{display:flex;flex-direction:column;gap:7px;}
.cand{display:flex;align-items:center;gap:10px;}
.cand .cw{font-family:'Fraunces',serif;font-weight:600;min-width:120px;color:var(--ink);}
.cand .bar{flex:1;height:12px;background:#efe6d2;border-radius:999px;overflow:hidden;}
.cand .bar span{display:block;height:100%;background:linear-gradient(90deg,var(--violet),var(--rose));}
.cand .cp{min-width:54px;text-align:right;color:var(--ink-soft);font-size:.9rem;}
.cand .exp{font-size:.68rem;letter-spacing:.08em;text-transform:uppercase;color:var(--gold);}
.final{margin-top:10px;font-family:'Fraunces',serif;font-size:1.15rem;color:var(--ink);}
.ls-foot{color:var(--ink-soft);font-size:.82rem;font-style:italic;text-align:center;margin-top:12px;}
footer{display:none !important;}
"""
# Keyword arguments for gr.Blocks(). css/theme are added here only when this
# Gradio version's Blocks() accepts a css argument; otherwise the __main__
# block passes them to launch() instead.
_bk = {"title": "Longshot"}
if _BLOCKS_HAS_CSS:
    _bk.update(css=CSS, theme=gr.themes.Soft())
with gr.Blocks(**_bk) as demo:
    state = gr.State(base_state())
    # Name the model actually configured (MODEL_ID may be overridden through
    # LONGSHOT_MODEL) instead of hard-coding it. For the default
    # "openbmb/MiniCPM5-1B" this yields "MiniCPM5-1B".
    model_name = MODEL_ID.rstrip("/").split("/")[-1]
    public = f"{model_name} · real logits" if MODE == "model" else "Longshot"
    mode_label = f"{public} · [{MODE}]" if DEBUG else public
    # mode_label may contain environment-provided text, so escape it.
    gr.HTML(f"""
<div><div class="ls-title">Long<em>shot</em></div>
<div class="ls-sub">Build a sentence from a model's own predictions. The stranger your path, the higher your score.</div>
<span class="ls-mode">{esc(mode_label)}</span></div>""")
    with gr.Row():
        seed = gr.Textbox(placeholder="an opening line... (The old lighthouse keeper)",
                          show_label=False, scale=8, autofocus=True)
        begin = gr.Button("Begin", variant="primary", scale=2)
    with gr.Row():
        seed_btns = [gr.Button(s, size="sm") for s in SEEDS]
    stage = gr.HTML(render_stage(base_state()))
    with gr.Row():
        # Fixed pool of K pick buttons; view() shows/labels them per step.
        pick_btns = [gr.Button("\u00b7", visible=False) for _ in range(K)]
    with gr.Row():
        longshot = gr.Button("🎲 Take the longshot", size="sm")
        oracle = gr.Button("✨ Oracle's pick", size="sm")
        finish = gr.Button("Finish the sentence", size="sm")
        newgame = gr.Button("New game", size="sm")
    share = gr.Textbox(label="Share your sentence", visible=False, interactive=True, **_COPY)
    gr.HTML('<div class="ls-foot">Surprise is measured in bits: choosing a word the model gave a 1-in-4 chance scores 2 bits; a 1-in-100 longshot scores about 6.6.</div>')
    # Output order must match the list returned by view().
    OUT = [stage] + pick_btns + [share, state]
    begin.click(on_begin, seed, OUT)
    seed.submit(on_begin, seed, OUT)
    for sb, s in zip(seed_btns, SEEDS):
        # gr.State carries this button's constant seed text as the input.
        sb.click(on_begin, gr.State(s), OUT)
    for idx, pb in enumerate(pick_btns):
        pb.click(on_pick, [gr.State(idx), state], OUT)
    longshot.click(on_longshot, state, OUT)
    oracle.click(on_oracle, state, OUT)
    finish.click(on_finish, state, OUT)
    newgame.click(on_new, None, OUT)
if __name__ == "__main__":
    launch_kwargs = {}
    if not _BLOCKS_HAS_CSS:
        # Blocks() did not accept css=, so pass the styling to launch() instead.
        launch_kwargs.update(css=CSS, theme=gr.themes.Soft())
    if _LAUNCH_HAS_SSR:
        launch_kwargs["ssr_mode"] = False
    demo.queue(max_size=24).launch(**launch_kwargs)