| """intellite 100M β RLHF data collector served as a Gradio HuggingFace Space. | |
| Every assistant reply gets π / π buttons. When the user rates a reply, | |
| the (system, prior messages, response, liked) tuple is appended to a | |
| local JSONL file, and a CommitScheduler pushes that folder to a dataset | |
| repo on the Hub every 5 minutes. | |
| Environment variables: | |
| INTELLITE_CKPT path to SFT checkpoint (default: ./best.pt) | |
| HF_TOKEN HF access token with *write* scope on the dataset | |
| repo (REQUIRED β set as a Space secret) | |
| FEEDBACK_REPO dataset repo id (default: ProCreations/Intellite-storage) | |
| """ | |
import json
import os
import sys
import threading
import time
import traceback
import uuid
from pathlib import Path

import gradio as gr
import tiktoken
import torch
from huggingface_hub import CommitScheduler

SPACE_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SPACE_DIR))

from config import ModelConfig
from model import IntelliteGPT

# ------------------------------------------------------------------------
# Paths & constants
CKPT_PATH = Path(os.environ.get("INTELLITE_CKPT", SPACE_DIR / "best.pt"))

FEEDBACK_DIR = SPACE_DIR / "user_feedback"
FEEDBACK_DIR.mkdir(exist_ok=True)
# Unique filename per replica/restart so concurrent Spaces don't clobber.
FEEDBACK_FILE = FEEDBACK_DIR / f"data_{uuid.uuid4().hex}.jsonl"

FEEDBACK_REPO = os.environ.get("FEEDBACK_REPO", "ProCreations/Intellite-storage")
HF_TOKEN = os.environ.get("HF_TOKEN")

DEFAULT_SYSTEM = "You are a helpful, honest, and concise assistant."
SYSTEM_TAG = "<|system|>\n"
USER_TAG = "<|user|>\n"
ASST_TAG = "<|assistant|>\n"
STOP_MARKERS = ("<|user|>", "<|system|>")
# ------------------------------------------------------------------------
# Model load (once, at startup)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[sys] device={DEVICE}  ckpt={CKPT_PATH}")

if not CKPT_PATH.exists():
    raise FileNotFoundError(
        f"No checkpoint at {CKPT_PATH}. Upload your SFT best.pt to the Space "
        f"root, or set the INTELLITE_CKPT environment variable to its path."
    )

sd = torch.load(str(CKPT_PATH), map_location=DEVICE)
_fields = ModelConfig.__dataclass_fields__.keys()
MCFG = ModelConfig(**{k: v for k, v in sd["model_cfg"].items() if k in _fields})

MODEL = IntelliteGPT(MCFG).to(DEVICE)
MODEL.load_state_dict(sd["model"])
MODEL.eval()

TOKENS_SEEN = int(sd.get("tokens_seen", 0))
BEST_VAL = float(sd.get("best_val", float("nan")))

ENC = tiktoken.get_encoding("gpt2")
EOT = ENC.eot_token

N_PARAMS = MODEL.num_params()
print(f"[model] {N_PARAMS/1e6:.1f}M params  tokens_seen={TOKENS_SEEN:,}  best_val={BEST_VAL:.4f}")
# ------------------------------------------------------------------------
# Hub sync – CommitScheduler pushes FEEDBACK_DIR to the dataset every 5 min.
if HF_TOKEN:
    scheduler = CommitScheduler(
        repo_id=FEEDBACK_REPO,
        repo_type="dataset",
        folder_path=FEEDBACK_DIR,
        path_in_repo="data",
        every=5,
        token=HF_TOKEN,
    )
    print(f"[hub] scheduler active → {FEEDBACK_REPO} (every 5 min)")
else:
    scheduler = None
    print("[hub] HF_TOKEN not set – feedback will stay local only")
# ------------------------------------------------------------------------
# Prompt templating + generation (mirrors chat.py)
def render_prompt_ids(system: str, prior_messages: list[dict], user_msg: str) -> list[int]:
    """Encode the SFT chat template exactly as sft_prepare.py did."""
    ids: list[int] = []
    if system:
        ids.extend(ENC.encode_ordinary(SYSTEM_TAG + system.strip() + "\n"))
    pending_user = None
    for m in prior_messages:
        role = m.get("role")
        content = (m.get("content") or "").strip()
        if role == "user":
            pending_user = content
        elif role == "assistant" and pending_user is not None:
            ids.extend(ENC.encode_ordinary(USER_TAG + pending_user + "\n"))
            ids.extend(ENC.encode_ordinary(ASST_TAG))
            ids.extend(ENC.encode_ordinary(content))
            ids.append(EOT)
            pending_user = None
    ids.extend(ENC.encode_ordinary(USER_TAG + user_msg.strip() + "\n"))
    ids.extend(ENC.encode_ordinary(ASST_TAG))
    return ids
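

# For illustration only: with the default system prompt and a single user turn
# "Hi!", the ids returned above decode back to the text below. The end-of-text
# token is appended after prior assistant turns and is emitted by the model
# itself to close the open turn.
#
#   <|system|>
#   You are a helpful, honest, and concise assistant.
#   <|user|>
#   Hi!
#   <|assistant|>
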

@torch.no_grad()  # inference only – no autograd graph is needed during generation
def stream_reply(prompt_ids, max_new, temperature, top_k, top_p, rep_penalty):
    """Yield the partial assistant reply after each new token."""
    x = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    ctx = MCFG.seq_len
    start = len(prompt_ids)
    reply = ""
    for _ in range(max_new):
        xc = x[:, -ctx:]
        if DEVICE == "cuda":
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, _ = MODEL(xc)
        else:
            logits, _ = MODEL(xc)
        logits = logits[0, -1, :].float()
        # Repetition penalty: dampen logits of tokens already in the context.
        if rep_penalty and rep_penalty != 1.0:
            seen = torch.unique(x[0])
            prev = logits[seen]
            logits[seen] = torch.where(prev > 0, prev / rep_penalty, prev * rep_penalty)
        logits = logits / max(temperature, 1e-5)
        # Top-k: keep only the k highest logits.
        if top_k and top_k > 0:
            k = min(int(top_k), logits.numel())
            v, _ = torch.topk(logits, k)
            logits[logits < v[-1]] = -float("inf")
        # Top-p (nucleus): keep the smallest prefix of tokens whose cumulative
        # probability mass exceeds top_p.
        if top_p and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cum > top_p
            mask[1:] = mask[:-1].clone()
            mask[0] = False
            logits[sorted_idx[mask]] = -float("inf")
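        # Worked example (illustrative numbers): with sorted probs
        # [0.5, 0.3, 0.15, 0.05] and top_p=0.9 the cumulative sums are
        # [0.5, 0.8, 0.95, 1.0]; after the one-position shift only the last
        # token is masked out, so the three most likely tokens stay in the nucleus.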
        probs = torch.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)
        tok_id = int(next_tok.item())
        x = torch.cat([x, next_tok.unsqueeze(0)], dim=1)
        if tok_id == EOT:
            break
        reply = ENC.decode(x[0, start:].tolist())
        while reply.endswith("\ufffd"):
            reply = reply[:-1]
        hit_stop = False
        for marker in STOP_MARKERS:
            idx = reply.find(marker)
            if idx != -1:
                reply = reply[:idx]
                hit_stop = True
                break
        if hit_stop:
            break
        yield reply.strip()
    yield reply.strip()

# ------------------------------------------------------------------------
# Feedback store – JSONL, append-only, synced to Hub by CommitScheduler.
_local_lock = threading.Lock()
_local_count = {"total": 0, "liked": 0}


def _count_jsonl_lines(path: Path) -> tuple[int, int]:
    total, liked = 0, 0
    if not path.exists():
        return 0, 0
    with path.open() as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                if json.loads(line).get("liked"):
                    liked += 1
            except json.JSONDecodeError:
                pass
    return total, liked


t, l = _count_jsonl_lines(FEEDBACK_FILE)
_local_count["total"], _local_count["liked"] = t, l


def _stats_str() -> str:
    t = _local_count["total"]
    l = _local_count["liked"]
    repo_link = f"[`{FEEDBACK_REPO}`](https://huggingface.co/datasets/{FEEDBACK_REPO})"
    sync = "synced every 5 min" if scheduler else "**HF_TOKEN missing – not syncing**"
    return (
        f"**{t}** records this session · 👍 {l} · 👎 {t - l}  \n"
        f"Pushed to {repo_link} ({sync})"
    )


def save_feedback(evt: gr.LikeData, history: list, system: str) -> str:
    """Handle a thumbs-up / thumbs-down click on a chat message."""
    if evt.liked is None:
        return "rating cleared (nothing saved)"
    idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    if not isinstance(idx, int) or idx < 0 or idx >= len(history):
        return f"bad index {evt.index!r}"
    msg = history[idx]
    if msg.get("role") != "assistant":
        return "skipped non-assistant message"
    record = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "system": (system or DEFAULT_SYSTEM).strip(),
        "prompt_messages": history[:idx],
        "response": msg.get("content", ""),
        "liked": bool(evt.liked),
    }
    # Write under the scheduler's lock (or our own) so the background push
    # never sees a half-written line.
    lock = scheduler.lock if scheduler else _local_lock
    with lock:
        with FEEDBACK_FILE.open("a") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        _local_count["total"] += 1
        if record["liked"]:
            _local_count["liked"] += 1
    verdict = "👍 good" if evt.liked else "👎 bad"
    return f"saved {verdict} · {_local_count['total']} this session"

# ------------------------------------------------------------------------
# Chat callback
def chat(user_msg, history, system, max_new, temperature, top_k, top_p, rep_penalty):
    """Stream a reply; yield updated chatbot history after each token."""
    user_msg = (user_msg or "").strip()
    if not user_msg:
        yield history, ""
        return
    history = list(history) + [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": ""},
    ]
    prior = history[:-2]
    ids = render_prompt_ids(system or DEFAULT_SYSTEM, prior, user_msg)
    room = MCFG.seq_len - int(max_new)
    if len(ids) > room > 0:
        ids = ids[-room:]
    try:
        for partial in stream_reply(ids, int(max_new), float(temperature),
                                    int(top_k), float(top_p), float(rep_penalty)):
            history[-1]["content"] = partial
            yield history, ""
    except Exception:
        history[-1]["content"] = f"[error] {traceback.format_exc()}"
        yield history, ""

# ------------------------------------------------------------------------
# UI
with gr.Blocks(
    title="intellite 100M – RLHF collector",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
) as demo:
    gr.Markdown(
        f"# intellite 100M – RLHF data collector\n"
        f"{MCFG.d_model}d × {MCFG.n_layers}L × {MCFG.n_heads}h "
        f"({N_PARAMS/1e6:.1f}M params) · {TOKENS_SEEN/1e6:.0f}M SFT tokens · "
        f"val_loss {BEST_VAL:.3f} · device `{DEVICE}`  \n"
        f"**Please rate every response with 👍 or 👎.** Ratings auto-sync to "
        f"[`{FEEDBACK_REPO}`](https://huggingface.co/datasets/{FEEDBACK_REPO}) "
        f"every 5 minutes for RLHF training."
    )

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                type="messages",
                height=520,
                show_copy_button=True,
                avatar_images=(None, None),
            )
            msg = gr.Textbox(
                placeholder="Your message – Enter to send",
                lines=2,
                show_label=False,
                autofocus=True,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear chat")
            feedback_status = gr.Markdown("_rate replies with 👍 / 👎_")
        with gr.Column(scale=1):
            system = gr.Textbox(value=DEFAULT_SYSTEM, label="System prompt", lines=3)
            max_new = gr.Slider(16, 800, value=400, step=16, label="max new tokens")
            temp = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature")
            top_k = gr.Slider(0, 200, value=50, step=1, label="top-k (0 = off)")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top-p")
            rep = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="repetition penalty")
            gr.Markdown("### RLHF data")
            stats_md = gr.Markdown(_stats_str())

    send_btn.click(
        chat,
        inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
        outputs=[chatbot, msg],
    )
    msg.submit(
        chat,
        inputs=[msg, chatbot, system, max_new, temp, top_k, top_p, rep],
        outputs=[chatbot, msg],
    )
    clear_btn.click(lambda: [], None, chatbot, queue=False)
    chatbot.like(
        save_feedback,
        inputs=[chatbot, system],
        outputs=[feedback_status],
    ).then(lambda: _stats_str(), None, stats_md, queue=False)

demo.queue()

if __name__ == "__main__":
    demo.launch()