Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python3 | |
| # app.py | |
| import os | |
| import sys | |
| import random | |
| import codecs | |
| from collections import Counter | |
| from typing import List, Dict, Set, Tuple, Optional, Any | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoModelForCausalLM | |
def set_seed(seed: int) -> None:
    """Seed every RNG used by the app: Python ``random``, torch, and (if present) all CUDA devices.

    Args:
        seed: Integer seed applied to each generator.
    """
    seeders = [random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
def build_prompt_text(user_prompt: str) -> str:
    """Return the text that is fed to the model for *user_prompt*.

    No template, system text, or role markers are added: the message is
    returned exactly as given.
    """
    prompt = user_prompt
    return prompt
def decode_base2_digits_strict(digits: List[int], *, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Decode a sequence of bit tokens back into text.

    Only tokens equal to 0 or 1 are treated as bits. Anything else, for
    example the bos/eos ids, is skipped. Bits are packed eight per byte,
    most significant bit first. A trailing group of fewer than eight bits is
    dropped.

    Args:
        digits: Token ids; each is converted with ``int()`` before filtering.
        encoding: Codec used for the packed bytes.
        errors: Codec error handler (``"replace"`` by default).

    Returns:
        The decoded string, or ``""`` when fewer than eight bits are present.
    """
    bits = [value for value in map(int, digits) if value in (0, 1)]

    usable = (len(bits) // 8) * 8
    if usable <= 0:
        return ""

    packed = bytearray()
    for start in range(0, usable, 8):
        byte = 0
        for bit in bits[start:start + 8]:
            byte = (byte << 1) | bit
        packed.append(byte)
    payload = bytes(packed)

    if encoding.lower() != "utf-8":
        return payload.decode(encoding, errors=errors)

    # UTF-8 goes through an incremental decoder. The final=True call flushes
    # any incomplete trailing multi-byte sequence through the error handler.
    decoder = codecs.getincrementaldecoder("utf-8")(errors=errors)
    head = decoder.decode(payload, final=False)
    return head + decoder.decode(b"", final=True)
def bytes_to_base2_digits_bytesafe(data: bytes) -> List[int]:
    """Expand each byte of *data* into eight 0/1 ints, most significant bit first.

    Operates on raw bytes (no text decoding), so every byte value 0-255 is
    representable.
    """
    return [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]
def text_to_base2_digits(text: str) -> List[int]:
    """Encode *text* as UTF-8 and return its bits as 0/1 ints, MSB first per byte."""
    utf8 = text.encode("utf-8")
    return bytes_to_base2_digits_bytesafe(utf8)
def wrap_base2_sequence_2(ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Return a new list ``[bos_id, *ids, eos_id]``, with the special ids cast to int."""
    head = [int(bos_id)]
    body = list(ids)
    tail = [int(eos_id)]
    return head + body + tail
def top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep the *top_k* largest logits along the last dimension; set the rest to -inf.

    ``top_k <= 0`` disables filtering and returns *logits* itself. Otherwise
    *top_k* is clamped to the size of the last dimension and a new tensor is
    returned.
    """
    if top_k <= 0:
        return logits
    k = min(top_k, logits.size(-1))
    _, keep_idx = torch.topk(logits, k, dim=-1)
    keep = torch.zeros_like(logits, dtype=torch.bool)
    keep.scatter_(-1, keep_idx, torch.ones_like(keep_idx, dtype=torch.bool))
    return logits.masked_fill(~keep, float("-inf"))
def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Apply nucleus (top-p) filtering along the last dimension.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability reaches *top_p*, and sets every other logit to -inf. The most
    likely token is always kept, so ``top_p <= 0`` behaves like greedy
    filtering. ``top_p >= 1`` disables filtering and returns *logits* itself.

    Args:
        logits: Logits of any rank; filtering is applied over the last dim.
        top_p: Cumulative-probability threshold.

    Returns:
        A tensor shaped like *logits* with filtered positions set to -inf.
    """
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Drop a token only when the tokens ranked above it already reach top_p.
    # Shifting the "reached" mask right by one keeps the token that crosses
    # the threshold, so the kept mass is >= top_p rather than < top_p.
    reached = cum >= top_p
    remove = torch.cat([torch.zeros_like(reached[..., :1]), reached[..., :-1]], dim=-1)
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    unsorted = torch.full_like(logits, float("-inf"))
    unsorted.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
    return unsorted
def apply_repetition_penalty_(logits_1d: torch.Tensor, token_ids: List[int], penalty: float) -> None:
    """Penalize, in place, the logits of every token id present in *token_ids*.

    Negative logits are multiplied by *penalty* and the others divided by it.
    For ``penalty > 1`` this makes already-seen tokens less likely. A penalty
    of ``None``, ``1.0`` or ``<= 0``, or an empty *token_ids*, is a no-op.
    Ids outside ``[0, logits_1d.numel())`` are ignored.
    """
    if penalty is None or penalty == 1.0 or penalty <= 0 or not token_ids:
        return
    vocab = logits_1d.numel()
    for tid in {int(t) for t in token_ids}:
        if not 0 <= tid < vocab:
            continue
        score = logits_1d[tid]
        if score < 0:
            logits_1d[tid] = score * penalty
        else:
            logits_1d[tid] = score / penalty
def apply_encoder_repetition_penalty_(logits_1d: torch.Tensor, prompt_ids: List[int], penalty: float) -> None:
    """Adjust, in place, the logits of tokens that appear in *prompt_ids*.

    This mirrors ``apply_repetition_penalty_`` with the operations swapped.
    Negative logits are divided by *penalty* and the others multiplied by it,
    so ``penalty > 1`` makes prompt tokens more likely. A penalty of ``None``,
    ``1.0`` or ``<= 0``, or an empty *prompt_ids*, is a no-op. Ids outside
    ``[0, logits_1d.numel())`` are ignored.
    """
    if penalty is None or penalty == 1.0 or penalty <= 0 or not prompt_ids:
        return
    vocab = logits_1d.numel()
    for tid in {int(t) for t in prompt_ids}:
        if not 0 <= tid < vocab:
            continue
        score = logits_1d[tid]
        if score < 0:
            logits_1d[tid] = score / penalty
        else:
            logits_1d[tid] = score * penalty
def apply_presence_frequency_penalties_(
    logits_1d: torch.Tensor,
    token_ids: List[int],
    presence_penalty: float,
    frequency_penalty: float,
) -> None:
    """Subtract presence and frequency penalties in place, OpenAI-style.

    For every in-range token id occurring ``c`` times in *token_ids*:
    ``logit -= presence_penalty``, then ``logit -= frequency_penalty * c``.
    A penalty that is ``None`` or ``0.0`` is skipped. Both disabled, or an
    empty *token_ids*, is a no-op.
    """
    presence_off = presence_penalty is None or presence_penalty == 0.0
    frequency_off = frequency_penalty is None or frequency_penalty == 0.0
    if (presence_off and frequency_off) or not token_ids:
        return

    occurrences = Counter(int(t) for t in token_ids)
    vocab = logits_1d.numel()
    use_presence = bool(presence_penalty)
    use_frequency = bool(frequency_penalty)
    for tid, count in occurrences.items():
        if not 0 <= tid < vocab:
            continue
        # Two separate subtractions, presence first, per token.
        if use_presence:
            logits_1d[tid] -= float(presence_penalty)
        if use_frequency:
            logits_1d[tid] -= float(frequency_penalty) * float(count)
def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> Set[int]:
    """Return the token ids that would complete an n-gram already present in *seq*.

    Every earlier occurrence of the last ``n - 1`` tokens of *seq* bans the
    token that followed it. With ``n == 1`` the prefix is empty, so every
    token already in *seq* is banned. ``n <= 0`` disables the check, and a
    sequence shorter than ``n - 1`` bans nothing.

    Args:
        seq: Token ids generated so far (prompt included, if the caller wants).
        n: N-gram size.

    Returns:
        A set of banned token ids (possibly empty).
    """
    if n <= 0:
        return set()
    length = len(seq)
    if length < n - 1:
        return set()
    prefix_len = n - 1
    followers: Dict[Tuple[int, ...], Set[int]] = {}
    for i in range(length - n + 1):
        prefix = tuple(int(x) for x in seq[i:i + prefix_len])
        followers.setdefault(prefix, set()).add(int(seq[i + prefix_len]))
    # Slice from an explicit start index. seq[-0:] would return the whole
    # sequence, so n == 1 could never match the empty prefix.
    key = tuple(int(x) for x in seq[length - prefix_len:])
    return followers.get(key, set())
def mask_banned_tokens_(logits_1d: torch.Tensor, banned: Set[int]) -> None:
    """Set ``logits_1d[t] = -inf`` in place for each id in *banned* inside ``[0, numel)``."""
    if banned:
        vocab = logits_1d.numel()
        neg_inf = float("-inf")
        for tid in (t for t in banned if 0 <= t < vocab):
            logits_1d[tid] = neg_inf
def sample_next_token(
    logits_1d: torch.Tensor,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
) -> int:
    """Choose the next token id from a 1-D logit vector.

    The logits are divided by *temperature* (a non-positive temperature falls
    back to 1.0), then passed through ``top_k_filter`` and ``top_p_filter``.
    When *do_sample* is false the argmax id is returned. Otherwise one id is
    drawn from the softmax of the filtered logits with ``torch.multinomial``.
    """
    temp = 1.0 if temperature <= 0 else temperature
    batched = (logits_1d / float(temp)).unsqueeze(0)
    batched = top_k_filter(batched, int(top_k))
    batched = top_p_filter(batched, float(top_p))
    filtered = batched[0]
    if not do_sample:
        return int(torch.argmax(filtered, dim=-1).item())
    choice = torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
    return int(choice.item())
| _MODEL: Optional[Any] = None | |
| _MODEL_KEY: Optional[str] = None | |
def get_device() -> str:
    """Return ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def load_model(repo: str, revision: str, hf_token: str, trust_remote_code: bool) -> Any:
    """Load a causal LM (or reuse the cached one) and move it to the current device.

    The model is cached in the module globals ``_MODEL`` / ``_MODEL_KEY``.
    The cache is reused only when repo, revision, device and
    trust_remote_code all match the previous load.

    Args:
        repo: Repo id or path passed to ``AutoModelForCausalLM.from_pretrained``.
        revision: Optional revision; an empty string means the default.
        hf_token: Access token; when empty, the ``HF_TOKEN`` and then the
            ``HUGGINGFACEHUB_API_TOKEN`` environment variables are tried.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        The model in eval mode, with ``config.use_cache`` set to True when a
        config is present.
    """
    global _MODEL, _MODEL_KEY
    device = get_device()
    trc = bool(trust_remote_code)
    # trust_remote_code is part of the key: toggling it must trigger a reload
    # rather than silently reusing a model loaded under the other setting.
    key = f"{repo}@{revision or ''}@{device}@trc={trc}"
    if _MODEL is not None and _MODEL_KEY == key:
        return _MODEL
    if _MODEL is not None:
        # Release the previous model before loading its replacement, so two
        # models are never resident at the same time (e.g. on a single GPU).
        import gc
        _MODEL = None
        _MODEL_KEY = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    tok = hf_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    print(f"[load_model] repo={repo} revision={revision or 'None'} device={device} trust_remote_code={trust_remote_code}", flush=True)
    m = AutoModelForCausalLM.from_pretrained(
        repo,
        revision=revision if revision else None,
        token=tok,
        trust_remote_code=trc,
        torch_dtype=None,
        low_cpu_mem_usage=True,
    )
    m.to(device)
    m.eval()
    if hasattr(m, "config") and m.config is not None:
        m.config.use_cache = True
    _MODEL = m
    _MODEL_KEY = key
    return _MODEL
def run_infer(
    repo: str,
    revision: str,
    hf_token: str,
    trust_remote_code: bool,
    bos_id: int,
    eos_id: int,
    user_prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    seed: int,
    stop_on_eos: bool,
    skip_special_decode: bool,
    repetition_penalty: float,
    presence_penalty: float,
    frequency_penalty: float,
    encoder_repetition_penalty: float,
    no_repeat_ngram_size: int,
    show_logs: bool,
) -> Dict[str, Any]:
    """Generate a reply to *user_prompt* with a bit-level (base-2) causal LM.

    The prompt is UTF-8 encoded and expanded into 0/1 tokens. It is wrapped as
    ``[bos] + bits + [eos]`` and a trailing ``bos`` is appended (presumably
    the start-of-reply marker). Tokens are then generated one at a time. At
    each step the encoder-repetition, repetition, presence/frequency and
    no-repeat-ngram adjustments are applied, followed by temperature and
    top-k/top-p sampling. Generated bit tokens are decoded back to text.

    Args:
        seed: RNG seed; ``None`` means 1234 and a negative value picks a
            fresh random seed.
        stop_on_eos: Stop once ``eos_id`` is generated (it is still recorded
            in ``out_ids_full``).
        skip_special_decode: Exclude bos/eos ids from ``gen_ids_for_decode``.

    Returns:
        Dict with keys ``seed``, ``prompt_text``, ``input_ids``,
        ``out_ids_full``, ``out_ids_generated``, ``gen_ids_for_decode`` and
        ``decoded``.

    NOTE(review): each step re-runs the model over the whole sequence (no KV
    cache reuse), so cost grows quadratically with ``max_new_tokens``.
    NOTE(review): ``spaces`` is imported but nothing visible here is
    decorated with ``@spaces.GPU``; on ZeroGPU hardware that may mean CUDA is
    never attached and generation runs on CPU. Confirm.
    """
    print("[DEBUG] APP VERSION = 2026-04-15-BINARY-RAW-PROMPT", flush=True)
    print(f"[DEBUG] repo={repo} revision={revision}", flush=True)
    print(f"[DEBUG] torch={torch.__version__}", flush=True)
    print(f"[DEBUG] cuda={torch.version.cuda}", flush=True)
    print(f"[DEBUG] device={get_device()}", flush=True)
    if seed is None:
        seed = 1234
    seed = int(seed)
    if seed < 0:
        # Use OS entropy. The module-level `random` state was reseeded by the
        # previous request's set_seed(), so random.randint would only yield a
        # deterministic chain of "random" seeds.
        seed = random.SystemRandom().randint(0, 2**31 - 1)
    set_seed(seed)
    model = load_model(repo, revision, hf_token, trust_remote_code)
    device = get_device()
    bos_id = int(bos_id)
    eos_id = int(eos_id)
    prompt_text = build_prompt_text(user_prompt)
    prompt_digits = text_to_base2_digits(prompt_text)
    input_ids: List[int] = wrap_base2_sequence_2(prompt_digits, bos_id, eos_id) + [bos_id]
    prompt_ids_for_penalty: List[int] = list(input_ids)
    out_ids = list(input_ids)
    gen_ids: List[int] = []
    if show_logs:
        print(f"\n[Seed] {seed}", flush=True)
        print("=== INPUT TEXT ===", flush=True)
        print(prompt_text, flush=True)
        print("=== INPUT IDS ===", flush=True)
        print(input_ids, flush=True)
        print(f"[info] input_len={len(input_ids)} base=2 bos_id={bos_id} eos_id={eos_id} device={device}", flush=True)
    # Inference only. Without no_grad, every forward pass records an autograd
    # graph (model parameters normally require grad), wasting memory and time.
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            x = torch.tensor([out_ids], dtype=torch.long, device=device)
            # The full sequence is re-fed every step and any returned KV cache
            # would be discarded, so don't ask the model to build one.
            out = model(input_ids=x, use_cache=False)
            next_logits = out.logits[0, -1, :].clone()
            apply_encoder_repetition_penalty_(next_logits, prompt_ids_for_penalty, float(encoder_repetition_penalty))
            apply_repetition_penalty_(next_logits, out_ids, float(repetition_penalty))
            apply_presence_frequency_penalties_(next_logits, out_ids, float(presence_penalty), float(frequency_penalty))
            if int(no_repeat_ngram_size) > 0:
                banned = get_banned_tokens_no_repeat_ngram(out_ids, int(no_repeat_ngram_size))
                mask_banned_tokens_(next_logits, banned)
            next_id = sample_next_token(
                next_logits,
                do_sample=bool(do_sample),
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
            )
            if next_id == eos_id and bool(stop_on_eos):
                out_ids.append(int(next_id))
                break
            out_ids.append(int(next_id))
            if bool(skip_special_decode):
                if next_id != bos_id and next_id != eos_id:
                    gen_ids.append(int(next_id))
            else:
                gen_ids.append(int(next_id))
    decoded = decode_base2_digits_strict(gen_ids, encoding="utf-8", errors="replace")
    if show_logs:
        print("\n=== OUTPUT IDS (FULL) ===", flush=True)
        print(out_ids, flush=True)
        print("\n=== OUTPUT IDS (GENERATED ONLY) ===", flush=True)
        print(out_ids[len(input_ids):], flush=True)
        print("\n=== DECODE (GENERATED ONLY) ===", flush=True)
        print(decoded, flush=True)
    return {
        "seed": seed,
        "prompt_text": prompt_text,
        "input_ids": input_ids,
        "out_ids_full": out_ids,
        "out_ids_generated": out_ids[len(input_ids):],
        "gen_ids_for_decode": gen_ids,
        "decoded": decoded,
    }
def ensure_history_msgs(history) -> List[Dict[str, str]]:
    """Return *history* if it looks like a list of ``{"role", "content"}`` dicts, else ``[]``.

    Only the first element is inspected. A valid list is returned as the same
    object (not copied). ``None``, an empty list, non-lists, and lists in any
    other shape all yield a fresh empty list.
    """
    if not isinstance(history, list) or not history:
        return []
    first = history[0]
    looks_like_messages = isinstance(first, dict) and "role" in first and "content" in first
    return history if looks_like_messages else []
def on_send(
    user_msg: str,
    history,
    repo,
    revision,
    hf_token,
    trust_remote_code,
    bos_id,
    eos_id,
    max_new_tokens,
    do_sample,
    temperature,
    top_p,
    top_k,
    seed,
    stop_on_eos,
    skip_special_decode,
    repetition_penalty,
    presence_penalty,
    frequency_penalty,
    encoder_repetition_penalty,
    no_repeat_ngram_size,
    show_debug,
    show_logs,
):
    """Gradio handler for one chat turn.

    Returns ``(chat_messages, state_messages, debug_text)``. The first two
    are the same list of ``{"role", "content"}`` dicts. A blank or
    whitespace-only message returns the existing history and an empty debug
    string without running the model. Otherwise the user message and the
    decoded model reply are appended. When *show_debug* is set, a text dump
    of the seed, ids and decoded output is returned as well.
    """
    # Work on a copy in case Gradio hands us the stored state list itself. If
    # run_infer raises, a half-finished turn (user message with no reply) is
    # then not left behind in the state.
    history_msgs = list(ensure_history_msgs(history))
    user_msg = (user_msg or "").strip()
    if user_msg == "":
        return history_msgs, history_msgs, ""
    history_msgs.append({"role": "user", "content": user_msg})
    res = run_infer(
        repo=str(repo),
        revision=str(revision or ""),
        hf_token=str(hf_token or ""),
        trust_remote_code=bool(trust_remote_code),
        bos_id=int(bos_id),
        eos_id=int(eos_id),
        user_prompt=str(user_msg),
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        # A cleared seed box can arrive as None. run_infer maps None to its
        # default seed; int(None) would raise TypeError here first.
        seed=None if seed is None else int(seed),
        stop_on_eos=bool(stop_on_eos),
        skip_special_decode=bool(skip_special_decode),
        repetition_penalty=float(repetition_penalty),
        presence_penalty=float(presence_penalty),
        frequency_penalty=float(frequency_penalty),
        encoder_repetition_penalty=float(encoder_repetition_penalty),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        show_logs=bool(show_logs),
    )
    assistant_text = res["decoded"]
    history_msgs.append({"role": "assistant", "content": assistant_text})
    debug_txt = ""
    if bool(show_debug):
        debug_txt = (
            f"[Seed] {res['seed']}\n"
            f"=== INPUT TEXT ===\n{res['prompt_text']}\n\n"
            f"=== INPUT IDS ===\n{res['input_ids']}\n\n"
            f"=== OUTPUT IDS (FULL) ===\n{res['out_ids_full']}\n\n"
            f"=== OUTPUT IDS (GENERATED ONLY) ===\n{res['out_ids_generated']}\n\n"
            f"=== DECODE (GENERATED ONLY) ===\n{assistant_text}\n"
        )
    return history_msgs, history_msgs, debug_txt
def on_clear():
    """Reset the chat display, the stored history, and the debug textbox.

    Returns two distinct empty lists (chat, state) and an empty string.
    """
    cleared_chat: list = []
    cleared_state: list = []
    return cleared_chat, cleared_state, ""
# Gradio UI: the chat column on the left, generation settings on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Binary-LLM-POC")
    with gr.Row():
        with gr.Column(scale=2):
            chat = gr.Chatbot(label="Chat")
            # Message-dict history; written back by on_send / on_clear.
            state = gr.State([])
            user_in = gr.Textbox(label="User", placeholder="Tape ton message…", lines=3)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
            debug_out = gr.Textbox(label="Debug (optional)", lines=18)
        with gr.Column(scale=1):
            gr.Markdown("### Model / HF")
            repo = gr.Textbox(label="Repo", value="PhysiQuanty/Binary-LLM-POC")
            revision = gr.Textbox(label="Revision (optional)", value="")
            hf_token = gr.Textbox(label="HF_TOKEN (optional)", value="", type="password")
            trust_remote_code = gr.Checkbox(label="trust_remote_code", value=True)

            gr.Markdown("### Specials / base")
            bos_id = gr.Number(label="bos_id", value=2, precision=0)
            eos_id = gr.Number(label="eos_id", value=3, precision=0)

            gr.Markdown("### Sampling")
            max_new_tokens = gr.Slider(label="max_new_tokens", minimum=1, maximum=4096, value=2048, step=1)
            do_sample = gr.Checkbox(label="do_sample", value=True)
            temperature = gr.Slider(label="temperature", minimum=0.01, maximum=2.0, value=0.7, step=0.01)
            top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=1.0, step=0.001)
            top_k = gr.Slider(label="top_k", minimum=0, maximum=50, value=50, step=1)
            seed = gr.Number(label="seed (-1=random)", value=-1, precision=0)

            gr.Markdown("### Stops / decode")
            stop_on_eos = gr.Checkbox(label="stop_on_eos", value=True)
            skip_special_decode = gr.Checkbox(label="skip_special_decode", value=True)

            gr.Markdown("### Penalties (5)")
            repetition_penalty = gr.Slider(label="repetition_penalty", minimum=0.0, maximum=3.0, value=1.0, step=0.01)
            encoder_repetition_penalty = gr.Slider(label="encoder_repetition_penalty", minimum=0.0, maximum=3.0, value=1.0, step=0.01)
            presence_penalty = gr.Slider(label="presence_penalty", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
            frequency_penalty = gr.Slider(label="frequency_penalty", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
            no_repeat_ngram_size = gr.Slider(label="no_repeat_ngram_size", minimum=0, maximum=20, value=0, step=1)

            gr.Markdown("### Debug / logs")
            show_debug = gr.Checkbox(label="Show debug block (UI)", value=True)
            show_logs = gr.Checkbox(label="Print logs to container logs", value=True)

    # The Send button and Enter in the textbox share identical wiring; the
    # order must match on_send's positional parameters.
    send_inputs = [
        user_in, state,
        repo, revision, hf_token, trust_remote_code,
        bos_id, eos_id,
        max_new_tokens, do_sample, temperature, top_p, top_k,
        seed, stop_on_eos, skip_special_decode,
        repetition_penalty, presence_penalty, frequency_penalty, encoder_repetition_penalty, no_repeat_ngram_size,
        show_debug, show_logs,
    ]
    send_outputs = [chat, state, debug_out]
    # Listener registration order is kept as-is (it determines auto-generated
    # endpoint names).
    send_btn.click(fn=on_send, inputs=send_inputs, outputs=send_outputs)
    user_in.submit(fn=on_send, inputs=send_inputs, outputs=send_outputs)
    clear_btn.click(fn=on_clear, inputs=[], outputs=[chat, state, debug_out])
    # Empty the input textbox after a send (separate events).
    send_btn.click(lambda: "", inputs=[], outputs=[user_in])
    user_in.submit(lambda: "", inputs=[], outputs=[user_in])

demo.queue().launch(ssr_mode=False)