"""intellite 100M β€” RLHF data collector served as a Gradio HuggingFace Space.
Every assistant reply gets πŸ‘ / πŸ‘Ž buttons. When the user rates a reply,
the (system, prior messages, response, liked) tuple is appended to a
local JSONL file, and a CommitScheduler pushes that folder to a dataset
repo on the Hub every 5 minutes.
Environment variables:
INTELLITE_CKPT path to SFT checkpoint (default: ./best.pt)
HF_TOKEN HF access token with *write* scope on the dataset
repo (REQUIRED β€” set as a Space secret)
FEEDBACK_REPO dataset repo id (default: ProCreations/Intellite-storage)
"""
import json
import os
import sys
import threading
import time
import traceback
import uuid
from pathlib import Path
import gradio as gr
import tiktoken
import torch
from huggingface_hub import CommitScheduler
SPACE_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SPACE_DIR))
from config import ModelConfig
from model import IntelliteGPT
# ------------------------------------------------------------------------
# Paths & constants
CKPT_PATH = Path(os.environ.get("INTELLITE_CKPT", SPACE_DIR / "best.pt"))
FEEDBACK_DIR = SPACE_DIR / "user_feedback"
FEEDBACK_DIR.mkdir(exist_ok=True)
# Unique filename per replica/restart so concurrent Spaces don't clobber.
FEEDBACK_FILE = FEEDBACK_DIR / f"data_{uuid.uuid4().hex}.jsonl"
FEEDBACK_REPO = os.environ.get("FEEDBACK_REPO", "ProCreations/Intellite-storage")
HF_TOKEN = os.environ.get("HF_TOKEN")
DEFAULT_SYSTEM = "You are a helpful, honest, and concise assistant."
SYSTEM_TAG = "<|system|>\n"
USER_TAG = "<|user|>\n"
ASST_TAG = "<|assistant|>\n"
STOP_MARKERS = ("<|user|>", "<|system|>")
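# A fully rendered prompt looks like this (illustrative; built by
# render_prompt_ids below, with the gpt2 EOT token closing each reply):
#
#   <|system|>
#   You are a helpful, honest, and concise assistant.
#   <|user|>
#   Hi there!
#   <|assistant|>
#   Hello! How can I help?<|endoftext|>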
# ------------------------------------------------------------------------
# Model load (once, at startup)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[sys] device={DEVICE} ckpt={CKPT_PATH}")
if not CKPT_PATH.exists():
    raise FileNotFoundError(
        f"No checkpoint at {CKPT_PATH}. Upload your SFT best.pt to the Space "
        f"root, or set the INTELLITE_CKPT environment variable to its path."
    )
sd = torch.load(str(CKPT_PATH), map_location=DEVICE)
_fields = ModelConfig.__dataclass_fields__.keys()
MCFG = ModelConfig(**{k: v for k, v in sd["model_cfg"].items() if k in _fields})
MODEL = IntelliteGPT(MCFG).to(DEVICE)
MODEL.load_state_dict(sd["model"])
MODEL.eval()
TOKENS_SEEN = int(sd.get("tokens_seen", 0))
BEST_VAL = float(sd.get("best_val", float("nan")))
ENC = tiktoken.get_encoding("gpt2")
EOT = ENC.eot_token
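# (for the gpt2 encoding, EOT is token id 50256, i.e. "<|endoftext|>")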
N_PARAMS = MODEL.num_params()
print(f"[model] {N_PARAMS/1e6:.1f}M params tokens_seen={TOKENS_SEEN:,} best_val={BEST_VAL:.4f}")
# ------------------------------------------------------------------------
# Hub sync: CommitScheduler pushes FEEDBACK_DIR to the dataset every 5 min.
if HF_TOKEN:
    scheduler = CommitScheduler(
        repo_id=FEEDBACK_REPO,
        repo_type="dataset",
        folder_path=FEEDBACK_DIR,
        path_in_repo="data",
        every=5,  # minutes between background commits
        token=HF_TOKEN,
    )
    print(f"[hub] scheduler active → {FEEDBACK_REPO} (every 5 min)")
else:
    scheduler = None
    print("[hub] HF_TOKEN not set; feedback will stay local only")
# ------------------------------------------------------------------------
# Prompt templating + generation (mirrors chat.py)
def render_prompt_ids(system: str, prior_messages: list[dict], user_msg: str) -> list[int]:
    """Encode the SFT chat template exactly as sft_prepare.py did."""
    ids: list[int] = []
    if system:
        ids.extend(ENC.encode_ordinary(SYSTEM_TAG + system.strip() + "\n"))
    # Replay only complete (user, assistant) pairs from the history; a
    # dangling user turn without a reply is dropped.
    pending_user = None
    for m in prior_messages:
        role = m.get("role")
        content = (m.get("content") or "").strip()
        if role == "user":
            pending_user = content
        elif role == "assistant" and pending_user is not None:
            ids.extend(ENC.encode_ordinary(USER_TAG + pending_user + "\n"))
            ids.extend(ENC.encode_ordinary(ASST_TAG))
            ids.extend(ENC.encode_ordinary(content))
            ids.append(EOT)
            pending_user = None
    ids.extend(ENC.encode_ordinary(USER_TAG + user_msg.strip() + "\n"))
    ids.extend(ENC.encode_ordinary(ASST_TAG))
    return ids
@torch.no_grad()
def stream_reply(prompt_ids, max_new, temperature, top_k, top_p, rep_penalty):
    """Yield the partial assistant reply after each new token."""
    x = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    ctx = MCFG.seq_len
    start = len(prompt_ids)
    reply = ""
    for _ in range(max_new):
        xc = x[:, -ctx:]  # crop to the model's context window
        if DEVICE == "cuda":
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, _ = MODEL(xc)
        else:
            logits, _ = MODEL(xc)
        logits = logits[0, -1, :].float()
        # Repetition penalty: dampen every token already present in the context.
        if rep_penalty and rep_penalty != 1.0:
            seen = torch.unique(x[0])
            prev = logits[seen]
            logits[seen] = torch.where(prev > 0, prev / rep_penalty, prev * rep_penalty)
        logits = logits / max(temperature, 1e-5)
        if top_k and top_k > 0:
            k = min(int(top_k), logits.numel())
            v, _ = torch.topk(logits, k)
            logits[logits < v[-1]] = -float("inf")
        if top_p and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cum > top_p
            mask[1:] = mask[:-1].clone()  # shift right: keep the first token past top_p
            mask[0] = False
            logits[sorted_idx[mask]] = -float("inf")
        probs = torch.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)
        tok_id = int(next_tok.item())
        x = torch.cat([x, next_tok.unsqueeze(0)], dim=1)
        if tok_id == EOT:
            break
        reply = ENC.decode(x[0, start:].tolist())
        # Trim a trailing replacement char from a partially decoded multi-byte token.
        while reply.endswith("\ufffd"):
            reply = reply[:-1]
        hit_stop = False
        for marker in STOP_MARKERS:
            idx = reply.find(marker)
            if idx != -1:
                reply = reply[:idx]
                hit_stop = True
                break
        if hit_stop:
            break
        yield reply.strip()
    yield reply.strip()
# ------------------------------------------------------------------------
# Feedback store: JSONL, append-only, synced to Hub by CommitScheduler.
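# Each line is one rating record written by save_feedback, e.g. (illustrative):
#   {"ts": "2025-01-01T12:00:00", "system": "...", "prompt_messages": [...],
#    "response": "...", "liked": true}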
_local_lock = threading.Lock()
_local_count = {"total": 0, "liked": 0}
def _count_jsonl_lines(path: Path) -> tuple[int, int]:
    total, liked = 0, 0
    if not path.exists():
        return 0, 0
    with path.open() as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                if json.loads(line).get("liked"):
                    liked += 1
            except json.JSONDecodeError:
                pass
    return total, liked

t, l = _count_jsonl_lines(FEEDBACK_FILE)
_local_count["total"], _local_count["liked"] = t, l
def _stats_str() -> str:
    t = _local_count["total"]
    l = _local_count["liked"]
    repo_link = f"[`{FEEDBACK_REPO}`](https://huggingface.co/datasets/{FEEDBACK_REPO})"
    sync = "synced every 5 min" if scheduler else "**HF_TOKEN missing, not syncing**"
    return (
        f"**{t}** records this session · 👍 {l} · 👎 {t - l}  \n"
        f"Pushed to {repo_link} ({sync})"
    )
def save_feedback(evt: gr.LikeData, history: list, system: str) -> str:
    """Handle a thumbs-up / thumbs-down click on a chat message."""
    if evt.liked is None:
        return "rating cleared (nothing saved)"
    idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    if not isinstance(idx, int) or idx < 0 or idx >= len(history):
        return f"bad index {evt.index!r}"
    msg = history[idx]
    if msg.get("role") != "assistant":
        return "skipped non-assistant message"
    record = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "system": (system or DEFAULT_SYSTEM).strip(),
        "prompt_messages": history[:idx],
        "response": msg.get("content", ""),
        "liked": bool(evt.liked),
    }
    # Write under the scheduler's lock (or our own) so the background push
    # never sees a half-written line.
    lock = scheduler.lock if scheduler else _local_lock
    with lock:
        with FEEDBACK_FILE.open("a") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        _local_count["total"] += 1
        if record["liked"]:
            _local_count["liked"] += 1
    verdict = "👍 good" if evt.liked else "👎 bad"
    return f"saved {verdict} · {_local_count['total']} this session"
# ------------------------------------------------------------------------
# Chat callback
def chat(user_msg, history, system, max_new, temperature, top_k, top_p, rep_penalty):
    """Stream a reply; yield updated chatbot history after each token."""
    user_msg = (user_msg or "").strip()
    if not user_msg:
        yield history, ""
        return
    history = list(history) + [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": ""},
    ]
    prior = history[:-2]
    ids = render_prompt_ids(system or DEFAULT_SYSTEM, prior, user_msg)
    # Trim the oldest prompt tokens so the reply still fits in the context window.
    room = MCFG.seq_len - int(max_new)
    if len(ids) > room > 0:
        ids = ids[-room:]
    try:
        for partial in stream_reply(ids, int(max_new), float(temperature),
                                    int(top_k), float(top_p), float(rep_penalty)):
            history[-1]["content"] = partial
            yield history, ""
    except Exception:
        history[-1]["content"] = f"[error] {traceback.format_exc()}"
        yield history, ""
# ------------------------------------------------------------------------
# UI
with gr.Blocks(
    title="intellite 100M - RLHF collector",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
) as demo:
    gr.Markdown(
        f"# intellite 100M - RLHF data collector\n"
        f"{MCFG.d_model}d × {MCFG.n_layers}L × {MCFG.n_heads}h "
        f"({N_PARAMS/1e6:.1f}M params) · {TOKENS_SEEN/1e6:.0f}M SFT tokens · "
        f"val_loss {BEST_VAL:.3f} · device `{DEVICE}`  \n"
        f"**Please rate every response with 👍 or 👎.** Ratings auto-sync to "
        f"[`{FEEDBACK_REPO}`](https://huggingface.co/datasets/{FEEDBACK_REPO}) "
        f"every 5 minutes for RLHF training."
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                type="messages",
                height=520,
                show_copy_button=True,
                avatar_images=(None, None),
            )
            msg = gr.Textbox(
                placeholder="Your message - Enter to send",
                lines=2,
                show_label=False,
                autofocus=True,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear chat")
            feedback_status = gr.Markdown("_rate replies with 👍 / 👎_")
        with gr.Column(scale=1):
            system = gr.Textbox(value=DEFAULT_SYSTEM, label="System prompt", lines=3)
            max_new = gr.Slider(16, 800, value=400, step=16, label="max new tokens")
            temp = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature")
            top_k = gr.Slider(0, 200, value=50, step=1, label="top-k (0 = off)")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top-p")
            rep = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="repetition penalty")
            gr.Markdown("### RLHF data")
            stats_md = gr.Markdown(_stats_str())

    send_btn.click(
        chat,
        inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
        outputs=[chatbot, msg],
    )
    msg.submit(
        chat,
        inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
        outputs=[chatbot, msg],
    )
    clear_btn.click(lambda: [], None, chatbot, queue=False)
    chatbot.like(
        save_feedback,
        inputs=[chatbot, system],
        outputs=[feedback_status],
    ).then(lambda: _stats_str(), None, stats_md, queue=False)
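# queue() enables request queuing, which Gradio uses to stream generator
# outputs; calling it at module level lets the HF Spaces launcher pick it up.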
demo.queue()
if __name__ == "__main__":
    demo.launch()