"""intellite 100M โ€” RLHF data collector served as a Gradio HuggingFace Space. Every assistant reply gets ๐Ÿ‘ / ๐Ÿ‘Ž buttons. When the user rates a reply, the (system, prior messages, response, liked) tuple is appended to a local JSONL file, and a CommitScheduler pushes that folder to a dataset repo on the Hub every 5 minutes. Environment variables: INTELLITE_CKPT path to SFT checkpoint (default: ./best.pt) HF_TOKEN HF access token with *write* scope on the dataset repo (REQUIRED โ€” set as a Space secret) FEEDBACK_REPO dataset repo id (default: ProCreations/Intellite-storage) """ import json import os import sys import threading import time import traceback import uuid from pathlib import Path import gradio as gr import tiktoken import torch from huggingface_hub import CommitScheduler SPACE_DIR = Path(__file__).resolve().parent sys.path.insert(0, str(SPACE_DIR)) from config import ModelConfig from model import IntelliteGPT # ------------------------------------------------------------------------ # Paths & constants CKPT_PATH = Path(os.environ.get("INTELLITE_CKPT", SPACE_DIR / "best.pt")) FEEDBACK_DIR = SPACE_DIR / "user_feedback" FEEDBACK_DIR.mkdir(exist_ok=True) # Unique filename per replica/restart so concurrent Spaces don't clobber. FEEDBACK_FILE = FEEDBACK_DIR / f"data_{uuid.uuid4().hex}.jsonl" FEEDBACK_REPO = os.environ.get("FEEDBACK_REPO", "ProCreations/Intellite-storage") HF_TOKEN = os.environ.get("HF_TOKEN") DEFAULT_SYSTEM = "You are a helpful, honest, and concise assistant." SYSTEM_TAG = "<|system|>\n" USER_TAG = "<|user|>\n" ASST_TAG = "<|assistant|>\n" STOP_MARKERS = ("<|user|>", "<|system|>") # ------------------------------------------------------------------------ # Model load (once, at startup) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"[sys] device={DEVICE} ckpt={CKPT_PATH}") if not CKPT_PATH.exists(): raise FileNotFoundError( f"No checkpoint at {CKPT_PATH}. Upload your SFT best.pt to the Space " f"root, or set the INTELLITE_CKPT environment variable to its path." ) sd = torch.load(str(CKPT_PATH), map_location=DEVICE) _fields = ModelConfig.__dataclass_fields__.keys() MCFG = ModelConfig(**{k: v for k, v in sd["model_cfg"].items() if k in _fields}) MODEL = IntelliteGPT(MCFG).to(DEVICE) MODEL.load_state_dict(sd["model"]) MODEL.eval() TOKENS_SEEN = int(sd.get("tokens_seen", 0)) BEST_VAL = float(sd.get("best_val", float("nan"))) ENC = tiktoken.get_encoding("gpt2") EOT = ENC.eot_token N_PARAMS = MODEL.num_params() print(f"[model] {N_PARAMS/1e6:.1f}M params tokens_seen={TOKENS_SEEN:,} best_val={BEST_VAL:.4f}") # ------------------------------------------------------------------------ # Hub sync โ€” CommitScheduler pushes FEEDBACK_DIR to the dataset every 5 min. 

# ------------------------------------------------------------------------
# Hub sync - CommitScheduler pushes FEEDBACK_DIR to the dataset every 5 min.

if HF_TOKEN:
    scheduler = CommitScheduler(
        repo_id=FEEDBACK_REPO,
        repo_type="dataset",
        folder_path=FEEDBACK_DIR,
        path_in_repo="data",
        every=5,
        token=HF_TOKEN,
    )
    print(f"[hub] scheduler active → {FEEDBACK_REPO} (every 5 min)")
else:
    scheduler = None
    print("[hub] HF_TOKEN not set - feedback will stay local only")

# ------------------------------------------------------------------------
# Prompt templating + generation (mirrors chat.py)


def render_prompt_ids(system: str, prior_messages: list[dict], user_msg: str) -> list[int]:
    """Encode the SFT chat template exactly as sft_prepare.py did."""
    ids: list[int] = []
    if system:
        ids.extend(ENC.encode_ordinary(SYSTEM_TAG + system.strip() + "\n"))
    pending_user = None
    for m in prior_messages:
        role = m.get("role")
        content = (m.get("content") or "").strip()
        if role == "user":
            pending_user = content
        elif role == "assistant" and pending_user is not None:
            ids.extend(ENC.encode_ordinary(USER_TAG + pending_user + "\n"))
            ids.extend(ENC.encode_ordinary(ASST_TAG))
            ids.extend(ENC.encode_ordinary(content))
            ids.append(EOT)
            pending_user = None
    ids.extend(ENC.encode_ordinary(USER_TAG + user_msg.strip() + "\n"))
    ids.extend(ENC.encode_ordinary(ASST_TAG))
    return ids


@torch.no_grad()
def stream_reply(prompt_ids, max_new, temperature, top_k, top_p, rep_penalty):
    """Yield the partial assistant reply after each new token."""
    x = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    ctx = MCFG.seq_len
    start = len(prompt_ids)
    reply = ""
    for _ in range(max_new):
        xc = x[:, -ctx:]
        if DEVICE == "cuda":
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, _ = MODEL(xc)
        else:
            logits, _ = MODEL(xc)
        logits = logits[0, -1, :].float()
        # Repetition penalty: dampen logits of every token already in context.
        if rep_penalty and rep_penalty != 1.0:
            seen = torch.unique(x[0])
            prev = logits[seen]
            logits[seen] = torch.where(prev > 0, prev / rep_penalty, prev * rep_penalty)
        logits = logits / max(temperature, 1e-5)
        if top_k and top_k > 0:
            k = min(int(top_k), logits.numel())
            v, _ = torch.topk(logits, k)
            logits[logits < v[-1]] = -float("inf")
        if top_p and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cum > top_p
            mask[1:] = mask[:-1].clone()  # shift so the first token over the threshold survives
            mask[0] = False
            logits[sorted_idx[mask]] = -float("inf")
        probs = torch.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)
        tok_id = int(next_tok.item())
        x = torch.cat([x, next_tok.unsqueeze(0)], dim=1)
        if tok_id == EOT:
            break
        reply = ENC.decode(x[0, start:].tolist())
        # Drop trailing replacement chars from partially-decoded multibyte tokens.
        while reply.endswith("\ufffd"):
            reply = reply[:-1]
        hit_stop = False
        for marker in STOP_MARKERS:
            idx = reply.find(marker)
            if idx != -1:
                reply = reply[:idx]
                hit_stop = True
                break
        if hit_stop:
            break
        yield reply.strip()
    yield reply.strip()
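
# Quick sanity check for the two functions above - a sketch, assuming the
# module-level MODEL/ENC have loaded; run it in a REPL to eyeball sampling
# settings without the Gradio UI:
#
#   ids = render_prompt_ids(DEFAULT_SYSTEM, [], "Say hello in five words.")
#   last = ""
#   for last in stream_reply(ids, max_new=64, temperature=0.7,
#                            top_k=50, top_p=0.9, rep_penalty=1.1):
#       pass
#   print(last)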

# ------------------------------------------------------------------------
# Feedback store - JSONL, append-only, synced to Hub by CommitScheduler.

_local_lock = threading.Lock()
_local_count = {"total": 0, "liked": 0}


def _count_jsonl_lines(path: Path) -> tuple[int, int]:
    total, liked = 0, 0
    if not path.exists():
        return 0, 0
    with path.open() as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                if json.loads(line).get("liked"):
                    liked += 1
            except json.JSONDecodeError:
                pass
    return total, liked


t, l = _count_jsonl_lines(FEEDBACK_FILE)
_local_count["total"], _local_count["liked"] = t, l


def _stats_str() -> str:
    t = _local_count["total"]
    l = _local_count["liked"]
    repo_link = f"[`{FEEDBACK_REPO}`](https://huggingface.co/datasets/{FEEDBACK_REPO})"
    sync = "synced every 5 min" if scheduler else "**HF_TOKEN missing - not syncing**"
    return (
        f"**{t}** records this session · 👍 {l} · 👎 {t - l}  \n"
        f"Pushed to {repo_link} ({sync})"
    )


def save_feedback(evt: gr.LikeData, history: list, system: str) -> str:
    """Handle a thumbs-up / thumbs-down click on a chat message."""
    if evt.liked is None:
        return "rating cleared (nothing saved)"
    idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    if not isinstance(idx, int) or idx < 0 or idx >= len(history):
        return f"bad index {evt.index!r}"
    msg = history[idx]
    if msg.get("role") != "assistant":
        return "skipped non-assistant message"
    record = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "system": (system or DEFAULT_SYSTEM).strip(),
        "prompt_messages": history[:idx],
        "response": msg.get("content", ""),
        "liked": bool(evt.liked),
    }
    # Write under the scheduler's lock (or our own) so the background push
    # never sees a half-written line.
    lock = scheduler.lock if scheduler else _local_lock
    with lock:
        with FEEDBACK_FILE.open("a") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        _local_count["total"] += 1
        if record["liked"]:
            _local_count["liked"] += 1
    verdict = "👍 good" if evt.liked else "👎 bad"
    return f"saved {verdict} · {_local_count['total']} this session"

# ------------------------------------------------------------------------
# Chat callback


def chat(user_msg, history, system, max_new, temperature, top_k, top_p, rep_penalty):
    """Stream a reply; yield updated chatbot history after each token."""
    user_msg = (user_msg or "").strip()
    if not user_msg:
        yield history, ""
        return
    history = list(history) + [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": ""},
    ]
    prior = history[:-2]
    ids = render_prompt_ids(system or DEFAULT_SYSTEM, prior, user_msg)
    # Left-truncate the prompt so `max_new` tokens still fit in the context.
    room = MCFG.seq_len - int(max_new)
    if len(ids) > room > 0:
        ids = ids[-room:]
    try:
        for partial in stream_reply(ids, int(max_new), float(temperature),
                                    int(top_k), float(top_p), float(rep_penalty)):
            history[-1]["content"] = partial
            yield history, ""
    except Exception:
        history[-1]["content"] = f"[error] {traceback.format_exc()}"
        yield history, ""
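
# Example of a line save_feedback() appends to FEEDBACK_FILE (values are
# made up for illustration; the real fields are built in `record` above):
#
#   {"ts": "2025-01-01T12:00:00",
#    "system": "You are a helpful, honest, and concise assistant.",
#    "prompt_messages": [{"role": "user", "content": "hi"}],
#    "response": "Hello! How can I help?",
#    "liked": true}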

# ------------------------------------------------------------------------
# UI

with gr.Blocks(
    title="intellite 100M - RLHF collector",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
) as demo:
    gr.Markdown(
        f"# intellite 100M - RLHF data collector\n"
        f"{MCFG.d_model}d × {MCFG.n_layers}L × {MCFG.n_heads}h "
        f"({N_PARAMS/1e6:.1f}M params) · {TOKENS_SEEN/1e6:.0f}M SFT tokens · "
        f"val_loss {BEST_VAL:.3f} · device `{DEVICE}`  \n"
        f"**Please rate every response with 👍 or 👎.** Ratings auto-sync to "
        f"[`{FEEDBACK_REPO}`](https://huggingface.co/datasets/{FEEDBACK_REPO}) "
        f"every 5 minutes for RLHF training."
    )

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                type="messages",
                height=520,
                show_copy_button=True,
                avatar_images=(None, None),
            )
            msg = gr.Textbox(
                placeholder="Your message - Enter to send",
                lines=2,
                show_label=False,
                autofocus=True,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear chat")
            feedback_status = gr.Markdown("_rate replies with 👍 / 👎_")
        with gr.Column(scale=1):
            system = gr.Textbox(value=DEFAULT_SYSTEM, label="System prompt", lines=3)
            max_new = gr.Slider(16, 800, value=400, step=16, label="max new tokens")
            temp = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature")
            top_k = gr.Slider(0, 200, value=50, step=1, label="top-k (0 = off)")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top-p")
            rep = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="repetition penalty")
            gr.Markdown("### RLHF data")
            stats_md = gr.Markdown(_stats_str())

    send_btn.click(
        chat,
        inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
        outputs=[chatbot, msg],
    )
    msg.submit(
        chat,
        inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
        outputs=[chatbot, msg],
    )
    clear_btn.click(lambda: [], None, chatbot, queue=False)
    chatbot.like(
        save_feedback,
        inputs=[chatbot, system],
        outputs=[feedback_status],
    ).then(lambda: _stats_str(), None, stats_md, queue=False)

demo.queue()

if __name__ == "__main__":
    demo.launch()
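
# Local run, assuming best.pt sits next to app.py (adjust paths/token to taste):
#
#   HF_TOKEN=hf_xxx python app.py    # chat + feedback synced to the Hub
#   python app.py                    # chat only; feedback stays local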