#!/usr/bin/env python3
# app.py
"""Gradio chat front-end for a causal LM that works on base-2 digit tokens.

The user message is UTF-8 encoded, expanded into a list of bits (token ids 0
and 1), wrapped as ``BOS <bits> EOS`` and followed by one more ``BOS``.  The
model is then sampled token by token with optional penalties and filters, and
the generated bits are packed back into bytes and decoded as text.
"""

import os
import sys
import random
import codecs
from collections import Counter
from typing import List, Dict, Set, Tuple, Optional, Any

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM


def set_seed(seed: int) -> None:
    """Seed Python's ``random`` and torch (plus every CUDA device if present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def build_prompt_text(user_prompt: str) -> str:
    """Return the prompt text for ``user_prompt`` (currently the raw message)."""
    return user_prompt


def decode_base2_digits_strict(
    digits: List[int], *, encoding: str = "utf-8", errors: str = "replace"
) -> str:
    """Pack a stream of bit digits into bytes (MSB first) and decode them.

    Digits whose integer value is neither 0 nor 1 are dropped before packing,
    and trailing bits that do not fill a whole byte are ignored.

    Args:
        digits: Token ids; only values equal to 0 or 1 contribute bits.
        encoding: Text encoding used to decode the packed bytes.
        errors: Error policy handed to the decoder.

    Returns:
        The decoded text, or ``""`` when fewer than 8 usable bits are present.
    """
    bits = [bit for bit in (int(d) for d in digits) if bit in (0, 1)]
    n_full_bytes = len(bits) // 8
    if n_full_bytes <= 0:
        return ""

    packed = bytearray(n_full_bytes)
    for byte_index in range(n_full_bytes):
        value = 0
        start = 8 * byte_index
        for bit in bits[start:start + 8]:
            value = (value << 1) | bit
        packed[byte_index] = value
    payload = bytes(packed)

    if encoding.lower() == "utf-8":
        decoder = codecs.getincrementaldecoder("utf-8")(errors=errors)
        text = decoder.decode(payload, final=False)
        text += decoder.decode(b"", final=True)
        return text
    return payload.decode(encoding, errors=errors)


def bytes_to_base2_digits_bytesafe(data: bytes) -> List[int]:
    """Expand each byte of ``data`` into its 8 bits, most significant first."""
    return [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]


def text_to_base2_digits(text: str) -> List[int]:
    """Return the bits of the UTF-8 encoding of ``text``."""
    return bytes_to_base2_digits_bytesafe(text.encode("utf-8"))


def wrap_base2_sequence_2(ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Return ``[bos_id, *ids, eos_id]``."""
    return [int(bos_id), *ids, int(eos_id)]


def top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep the ``top_k`` largest logits along the last dim; set the rest to -inf.

    ``top_k <= 0`` disables the filter and returns ``logits`` unchanged.
    """
    if top_k <= 0:
        return logits
    top_k = min(top_k, logits.size(-1))
    vals, idx = torch.topk(logits, top_k, dim=-1)
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(dim=-1, index=idx, src=vals)
    return out


def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filter: keep the smallest top set whose mass exceeds ``top_p``.

    Tokens are sorted by descending logit; a token is removed only when the
    cumulative probability of the tokens *before* it is already above
    ``top_p``.  The most likely token is therefore always kept, and the token
    that crosses the threshold is kept as well.  ``top_p >= 1.0`` disables the
    filter and returns ``logits`` unchanged.
    """
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum > top_p
    # Shift right by one so the token that crosses the threshold survives;
    # without the shift the kept mass could fall below top_p.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    unsorted = torch.full_like(logits, float("-inf"))
    unsorted.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
    return unsorted


def apply_repetition_penalty_(
    logits_1d: torch.Tensor, token_ids: List[int], penalty: float
) -> None:
    """Rescale, in place, the logit of every distinct id in ``token_ids``.

    Positive logits are divided by ``penalty`` and negative ones multiplied by
    it.  Nothing happens when ``penalty`` is ``None``, ``1.0`` or ``<= 0``, or
    when ``token_ids`` is empty.  Ids outside the vocabulary are skipped.
    """
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return
    if not token_ids:
        return
    vocab_size = logits_1d.numel()
    for token in {int(t) for t in token_ids}:
        if token < 0 or token >= vocab_size:
            continue
        val = logits_1d[token]
        logits_1d[token] = val * penalty if val < 0 else val / penalty


def apply_encoder_repetition_penalty_(
    logits_1d: torch.Tensor, prompt_ids: List[int], penalty: float
) -> None:
    """Rescale, in place, the logit of every distinct id in ``prompt_ids``.

    Mirror image of ``apply_repetition_penalty_``: positive logits are
    multiplied by ``penalty`` and negative ones divided by it.  Nothing
    happens when ``penalty`` is ``None``, ``1.0`` or ``<= 0``, or when
    ``prompt_ids`` is empty.  Ids outside the vocabulary are skipped.
    """
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return
    if not prompt_ids:
        return
    vocab_size = logits_1d.numel()
    for token in {int(t) for t in prompt_ids}:
        if token < 0 or token >= vocab_size:
            continue
        val = logits_1d[token]
        logits_1d[token] = val / penalty if val < 0 else val * penalty


def apply_presence_frequency_penalties_(
    logits_1d: torch.Tensor,
    token_ids: List[int],
    presence_penalty: float,
    frequency_penalty: float,
) -> None:
    """Subtract presence / frequency penalties from ``logits_1d`` in place.

    Every id occurring in ``token_ids`` loses ``presence_penalty`` once and
    ``frequency_penalty`` times its occurrence count.  A penalty that is
    ``None`` or ``0.0`` is skipped; ids outside the vocabulary are ignored.
    """
    if not presence_penalty and not frequency_penalty:
        return
    if not token_ids:
        return
    counts = Counter(int(t) for t in token_ids)
    vocab_size = logits_1d.numel()
    if presence_penalty:
        for token in counts:
            if 0 <= token < vocab_size:
                logits_1d[token] -= float(presence_penalty)
    if frequency_penalty:
        for token, count in counts.items():
            if 0 <= token < vocab_size:
                logits_1d[token] -= float(frequency_penalty) * float(count)


def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> Set[int]:
    """Return the tokens that would complete an ``n``-gram already in ``seq``.

    The last ``n - 1`` tokens of ``seq`` are looked up among all earlier
    ``(n - 1)``-token prefixes; every token that followed such a prefix is
    returned.  For ``n == 1`` the prefix is empty, so every token already in
    ``seq`` is returned.  ``n <= 0`` or a sequence shorter than ``n - 1``
    yields an empty set.
    """
    if n <= 0:
        return set()
    prefix_len = n - 1
    seq_len = len(seq)
    if seq_len < prefix_len:
        return set()

    tokens = [int(x) for x in seq]
    followers: Dict[Tuple[int, ...], Set[int]] = {}
    for start in range(seq_len - n + 1):
        prefix = tuple(tokens[start:start + prefix_len])
        followers.setdefault(prefix, set()).add(tokens[start + prefix_len])

    # Slice from an explicit start index: ``tokens[-prefix_len:]`` would return
    # the whole sequence when prefix_len == 0 and the lookup would never match.
    key = tuple(tokens[seq_len - prefix_len:])
    return followers.get(key, set())


def mask_banned_tokens_(logits_1d: torch.Tensor, banned: Set[int]) -> None:
    """Set the logit of every in-range id in ``banned`` to -inf, in place."""
    if not banned:
        return
    vocab_size = logits_1d.numel()
    for token in banned:
        if 0 <= token < vocab_size:
            logits_1d[token] = float("-inf")


def sample_next_token(
    logits_1d: torch.Tensor,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
) -> int:
    """Pick the next token id from a 1-D logits vector.

    The logits are divided by ``temperature`` (a value ``<= 0`` is treated as
    ``1.0``), filtered by top-k and then top-p.  With ``do_sample`` false the
    arg-max is returned, otherwise one id is drawn from the softmax.
    """
    if temperature <= 0:
        temperature = 1.0
    logits = logits_1d / float(temperature)
    logits = top_k_filter(logits.unsqueeze(0), int(top_k))[0]
    logits = top_p_filter(logits.unsqueeze(0), float(top_p))[0]
    if not do_sample:
        return int(torch.argmax(logits, dim=-1).item())
    probs = torch.softmax(logits, dim=-1)
    next_id = torch.multinomial(probs, num_samples=1)
    return int(next_id.item())


# Process-wide cache of the most recently loaded model and the key it was
# loaded under (see ``load_model``).
_MODEL: Optional[Any] = None
_MODEL_KEY: Optional[str] = None


def get_device() -> str:
    """Return ``"cuda"`` when CUDA is available, else ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_model(repo: str, revision: str, hf_token: str, trust_remote_code: bool) -> Any:
    """Load (or reuse) the causal LM for ``repo`` and move it to the device.

    The loaded model is cached in module globals under the key
    ``repo@revision@device``; a call with the same key returns the cached
    instance.  When ``hf_token`` is empty the token falls back to the
    ``HF_TOKEN`` and then ``HUGGINGFACEHUB_API_TOKEN`` environment variables.
    """
    global _MODEL, _MODEL_KEY
    device = get_device()
    key = f"{repo}@{revision or ''}@{device}"
    if _MODEL is not None and _MODEL_KEY == key:
        return _MODEL

    tok = hf_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    print(f"[load_model] repo={repo} revision={revision or 'None'} device={device} trust_remote_code={trust_remote_code}", flush=True)
    m = AutoModelForCausalLM.from_pretrained(
        repo,
        revision=revision if revision else None,
        token=tok,
        trust_remote_code=bool(trust_remote_code),
        torch_dtype=None,
        low_cpu_mem_usage=True,
    )
    m.to(device)
    m.eval()
    if hasattr(m, "config") and m.config is not None:
        m.config.use_cache = True

    _MODEL = m
    _MODEL_KEY = key
    return _MODEL


@spaces.GPU
@torch.no_grad()
def run_infer(
    repo: str,
    revision: str,
    hf_token: str,
    trust_remote_code: bool,
    bos_id: int,
    eos_id: int,
    user_prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    seed: Optional[int],
    stop_on_eos: bool,
    skip_special_decode: bool,
    repetition_penalty: float,
    presence_penalty: float,
    frequency_penalty: float,
    encoder_repetition_penalty: float,
    no_repeat_ngram_size: int,
    show_logs: bool,
) -> Dict[str, Any]:
    """Generate a reply for ``user_prompt`` and decode it from bits to text.

    The prompt bits are wrapped as ``BOS <bits> EOS`` with one extra ``BOS``
    appended, then up to ``max_new_tokens`` tokens are sampled one at a time.
    Before sampling, each step applies the encoder-repetition, repetition,
    presence/frequency and no-repeat-n-gram adjustments to the logits.

    Args:
        seed: ``None`` selects 1234; a negative value draws a random seed.
        stop_on_eos: Stop after the first generated ``eos_id``.
        skip_special_decode: Exclude generated BOS/EOS ids from the decoded
            bit stream.
        show_logs: Print inputs and outputs to stdout.

    Returns:
        A dict with keys ``seed``, ``prompt_text``, ``input_ids``,
        ``out_ids_full``, ``out_ids_generated``, ``gen_ids_for_decode`` and
        ``decoded``.
    """
    print("[DEBUG] APP VERSION = 2026-04-15-BINARY-RAW-PROMPT", flush=True)
    print(f"[DEBUG] repo={repo} revision={revision}", flush=True)
    print(f"[DEBUG] torch={torch.__version__}", flush=True)
    print(f"[DEBUG] cuda={torch.version.cuda}", flush=True)
    print(f"[DEBUG] device={get_device()}", flush=True)

    if seed is None:
        seed = 1234
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    set_seed(seed)

    model = load_model(repo, revision, hf_token, trust_remote_code)
    device = get_device()

    # Normalise the knobs once instead of on every generation step.
    bos_id = int(bos_id)
    eos_id = int(eos_id)
    stop_on_eos = bool(stop_on_eos)
    skip_special_decode = bool(skip_special_decode)
    ngram_size = int(no_repeat_ngram_size)

    prompt_text = build_prompt_text(user_prompt)
    prompt_digits = text_to_base2_digits(prompt_text)
    input_ids: List[int] = wrap_base2_sequence_2(prompt_digits, bos_id, eos_id) + [bos_id]

    prompt_ids_for_penalty: List[int] = list(input_ids)
    out_ids = list(input_ids)
    gen_ids: List[int] = []

    if show_logs:
        print(f"\n[Seed] {seed}", flush=True)
        print("=== INPUT TEXT ===", flush=True)
        print(prompt_text, flush=True)
        print("=== INPUT IDS ===", flush=True)
        print(input_ids, flush=True)
        print(f"[info] input_len={len(input_ids)} base=2 bos_id={bos_id} eos_id={eos_id} device={device}", flush=True)

    # NOTE(review): the whole sequence is fed to the model on every step and
    # only ``out.logits`` is read, so any key/value cache the model returns is
    # not reused between steps.
    for _ in range(int(max_new_tokens)):
        x = torch.tensor([out_ids], dtype=torch.long, device=device)
        out = model(input_ids=x, use_cache=True)
        next_logits = out.logits[0, -1, :].clone()

        apply_encoder_repetition_penalty_(next_logits, prompt_ids_for_penalty, float(encoder_repetition_penalty))
        apply_repetition_penalty_(next_logits, out_ids, float(repetition_penalty))
        apply_presence_frequency_penalties_(next_logits, out_ids, float(presence_penalty), float(frequency_penalty))

        if ngram_size > 0:
            banned = get_banned_tokens_no_repeat_ngram(out_ids, ngram_size)
            mask_banned_tokens_(next_logits, banned)

        next_id = sample_next_token(
            next_logits,
            do_sample=bool(do_sample),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

        out_ids.append(next_id)
        if next_id == eos_id and stop_on_eos:
            break
        if skip_special_decode and next_id in (bos_id, eos_id):
            continue
        gen_ids.append(next_id)

    decoded = decode_base2_digits_strict(gen_ids, encoding="utf-8", errors="replace")
    generated_only = out_ids[len(input_ids):]

    if show_logs:
        print("\n=== OUTPUT IDS (FULL) ===", flush=True)
        print(out_ids, flush=True)
        print("\n=== OUTPUT IDS (GENERATED ONLY) ===", flush=True)
        print(generated_only, flush=True)
        print("\n=== DECODE (GENERATED ONLY) ===", flush=True)
        print(decoded, flush=True)

    return {
        "seed": seed,
        "prompt_text": prompt_text,
        "input_ids": input_ids,
        "out_ids_full": out_ids,
        "out_ids_generated": generated_only,
        "gen_ids_for_decode": gen_ids,
        "decoded": decoded,
    }


def ensure_history_msgs(history) -> List[Dict[str, str]]:
    """Return a copy of ``history`` if it looks like a role/content message list.

    ``history`` is accepted when it is a non-empty list whose first element is
    a dict with ``"role"`` and ``"content"`` keys; anything else yields ``[]``.
    A shallow copy is returned so callers can append without mutating the
    list they were handed.
    """
    if isinstance(history, list) and history:
        first = history[0]
        if isinstance(first, dict) and "role" in first and "content" in first:
            return list(history)
    return []


def on_send(
    user_msg: str,
    history,
    repo,
    revision,
    hf_token,
    trust_remote_code,
    bos_id,
    eos_id,
    max_new_tokens,
    do_sample,
    temperature,
    top_p,
    top_k,
    seed,
    stop_on_eos,
    skip_special_decode,
    repetition_penalty,
    presence_penalty,
    frequency_penalty,
    encoder_repetition_penalty,
    no_repeat_ngram_size,
    show_debug,
    show_logs,
):
    """Handle a chat submission: run inference and extend the history.

    Returns a ``(chat messages, state messages, debug text)`` triple.  A blank
    message leaves the history untouched and returns an empty debug text.
    """
    history_msgs = ensure_history_msgs(history)
    user_msg = (user_msg or "").strip()
    if user_msg == "":
        return history_msgs, history_msgs, ""

    history_msgs.append({"role": "user", "content": user_msg})

    res = run_infer(
        repo=str(repo),
        revision=str(revision or ""),
        hf_token=str(hf_token or ""),
        trust_remote_code=bool(trust_remote_code),
        bos_id=int(bos_id),
        eos_id=int(eos_id),
        user_prompt=str(user_msg),
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        # A cleared seed field arrives as None; let run_infer pick its default
        # instead of failing in int(None).
        seed=None if seed is None else int(seed),
        stop_on_eos=bool(stop_on_eos),
        skip_special_decode=bool(skip_special_decode),
        repetition_penalty=float(repetition_penalty),
        presence_penalty=float(presence_penalty),
        frequency_penalty=float(frequency_penalty),
        encoder_repetition_penalty=float(encoder_repetition_penalty),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        show_logs=bool(show_logs),
    )

    assistant_text = res["decoded"]
    history_msgs.append({"role": "assistant", "content": assistant_text})

    debug_txt = ""
    if bool(show_debug):
        debug_txt = (
            f"[Seed] {res['seed']}\n"
            f"=== INPUT TEXT ===\n{res['prompt_text']}\n\n"
            f"=== INPUT IDS ===\n{res['input_ids']}\n\n"
            f"=== OUTPUT IDS (FULL) ===\n{res['out_ids_full']}\n\n"
            f"=== OUTPUT IDS (GENERATED ONLY) ===\n{res['out_ids_generated']}\n\n"
            f"=== DECODE (GENERATED ONLY) ===\n{assistant_text}\n"
        )

    return history_msgs, history_msgs, debug_txt


def on_clear():
    """Reset the chat, the stored history and the debug box."""
    return [], [], ""


with gr.Blocks() as demo:
    gr.Markdown("## Binary-LLM-POC")

    with gr.Row():
        with gr.Column(scale=2):
            chat = gr.Chatbot(label="Chat")
            state = gr.State([])
            user_in = gr.Textbox(label="User", placeholder="Tape ton message…", lines=3)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
            debug_out = gr.Textbox(label="Debug (optional)", lines=18)

        with gr.Column(scale=1):
            gr.Markdown("### Model / HF")
            repo = gr.Textbox(label="Repo", value="PhysiQuanty/Binary-LLM-POC")
            revision = gr.Textbox(label="Revision (optional)", value="")
            hf_token = gr.Textbox(label="HF_TOKEN (optional)", value="", type="password")
            trust_remote_code = gr.Checkbox(label="trust_remote_code", value=True)

            gr.Markdown("### Specials / base")
            bos_id = gr.Number(label="bos_id", value=2, precision=0)
            eos_id = gr.Number(label="eos_id", value=3, precision=0)

            gr.Markdown("### Sampling")
            max_new_tokens = gr.Slider(label="max_new_tokens", minimum=1, maximum=4096, value=2048, step=1)
            do_sample = gr.Checkbox(label="do_sample", value=True)
            temperature = gr.Slider(label="temperature", minimum=0.01, maximum=2.0, value=0.7, step=0.01)
            top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=1.0, step=0.001)
            top_k = gr.Slider(label="top_k", minimum=0, maximum=50, value=50, step=1)
            seed = gr.Number(label="seed (-1=random)", value=-1, precision=0)

            gr.Markdown("### Stops / decode")
            stop_on_eos = gr.Checkbox(label="stop_on_eos", value=True)
            skip_special_decode = gr.Checkbox(label="skip_special_decode", value=True)

            gr.Markdown("### Penalties (5)")
            repetition_penalty = gr.Slider(label="repetition_penalty", minimum=0.0, maximum=3.0, value=1.0, step=0.01)
            encoder_repetition_penalty = gr.Slider(label="encoder_repetition_penalty", minimum=0.0, maximum=3.0, value=1.0, step=0.01)
            presence_penalty = gr.Slider(label="presence_penalty", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
            frequency_penalty = gr.Slider(label="frequency_penalty", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
            no_repeat_ngram_size = gr.Slider(label="no_repeat_ngram_size", minimum=0, maximum=20, value=0, step=1)

            gr.Markdown("### Debug / logs")
            show_debug = gr.Checkbox(label="Show debug block (UI)", value=True)
            show_logs = gr.Checkbox(label="Print logs to container logs", value=True)

    # Same components, in on_send's positional order, for both triggers.
    send_inputs = [
        user_in, state,
        repo, revision, hf_token, trust_remote_code,
        bos_id, eos_id,
        max_new_tokens, do_sample, temperature, top_p, top_k, seed,
        stop_on_eos, skip_special_decode,
        repetition_penalty, presence_penalty, frequency_penalty,
        encoder_repetition_penalty, no_repeat_ngram_size,
        show_debug, show_logs,
    ]
    send_outputs = [chat, state, debug_out]

    send_btn.click(fn=on_send, inputs=send_inputs, outputs=send_outputs)
    user_in.submit(fn=on_send, inputs=send_inputs, outputs=send_outputs)

    clear_btn.click(fn=on_clear, inputs=[], outputs=[chat, state, debug_out])

    # Second listener on the same events: empty the input box after sending.
    send_btn.click(lambda: "", inputs=[], outputs=[user_in])
    user_in.submit(lambda: "", inputs=[], outputs=[user_in])

demo.queue().launch(ssr_mode=False)