Binary-LLM-POC / app.py
PhysiQuanty's picture
Create app.py
45bfd05 verified
#!/usr/bin/env python3
# app.py
import os
import sys
import random
import codecs
from collections import Counter
from typing import List, Dict, Set, Tuple, Optional, Any
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM
def set_seed(seed: int) -> None:
    """Seed every random number generator used during generation.

    Seeds Python's ``random`` module and PyTorch's default generator, and,
    when CUDA is available, the generators of all CUDA devices.

    Args:
        seed: Value passed unchanged to each generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def build_prompt_text(user_prompt: str) -> str:
    """Return the text that gets encoded as the model prompt.

    The user message is passed through verbatim: no template, prefix or
    suffix is added.
    """
    prompt_text = user_prompt
    return prompt_text
def decode_base2_digits_strict(digits: List[int], *, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Decode a stream of base-2 digits into text.

    Every item is converted with ``int()``; anything that is not 0 or 1 is
    dropped. The remaining bits are packed eight at a time, most significant
    bit first, and trailing bits that do not fill a whole byte are discarded.

    Args:
        digits: Digit/token ids to decode.
        encoding: Text encoding used to decode the packed bytes.
        errors: Error handler forwarded to the decoder.

    Returns:
        The decoded text, or ``""`` when fewer than eight valid bits remain.
    """
    clean = [v for v in map(int, digits) if v in (0, 1)]
    usable = len(clean) - len(clean) % 8
    if usable == 0:
        return ""

    packed = bytearray()
    for start in range(0, usable, 8):
        value = 0
        for bit in clean[start:start + 8]:
            value = (value << 1) | bit
        packed.append(value)
    raw = bytes(packed)

    if encoding.lower() != "utf-8":
        return raw.decode(encoding, errors=errors)

    # UTF-8 goes through an incremental decoder that is flushed explicitly.
    decoder = codecs.getincrementaldecoder("utf-8")(errors=errors)
    return decoder.decode(raw, final=False) + decoder.decode(b"", final=True)
def bytes_to_base2_digits_bytesafe(data: bytes) -> List[int]:
    """Expand bytes into a flat list of bits, most significant bit first.

    Each input byte contributes exactly eight entries (0 or 1) to the result.
    """
    return [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]
def text_to_base2_digits(text: str) -> List[int]:
    """Encode ``text`` as UTF-8 and return its bits as a list of 0/1 digits."""
    utf8_bytes = text.encode("utf-8")
    return bytes_to_base2_digits_bytesafe(utf8_bytes)
def wrap_base2_sequence_2(ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Return a new list: ``bos_id``, then every item of ``ids``, then ``eos_id``.

    The two marker ids are converted with ``int()``; ``ids`` is copied as-is
    and the input list is not modified.
    """
    wrapped: List[int] = [int(bos_id)]
    wrapped.extend(ids)
    wrapped.append(int(eos_id))
    return wrapped
def top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep the ``top_k`` largest logits along the last dimension.

    Args:
        logits: Scores to filter; the filter acts on the last dimension.
        top_k: Number of entries to keep. Values <= 0 disable the filter;
            values larger than the last dimension are clamped to its size.

    Returns:
        The input tensor itself when the filter is disabled, otherwise a new
        tensor in which every non-selected entry is ``-inf``.
    """
    if top_k <= 0:
        return logits
    keep = min(top_k, logits.size(-1))
    kept_values, kept_positions = torch.topk(logits, keep, dim=-1)
    filtered = torch.full_like(logits, float("-inf"))
    return filtered.scatter_(dim=-1, index=kept_positions, src=kept_values)
def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Apply nucleus (top-p) filtering along the last dimension.

    Keeps the smallest set of highest-probability entries whose cumulative
    probability reaches ``top_p`` and sets every other entry to ``-inf``.
    The most probable entry is always kept, so ``top_p <= 0`` degenerates to
    keeping only the argmax.

    Args:
        logits: Scores to filter; probabilities are the softmax over the last
            dimension.
        top_p: Cumulative probability threshold. Values >= 1.0 disable the
            filter.

    Returns:
        The input tensor itself when the filter is disabled, otherwise a new
        tensor of the same shape.
    """
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    crossed = cum > top_p
    # Shift the mask one slot to the right: the entry whose cumulative mass
    # first exceeds top_p belongs to the nucleus and must be kept. Masking it
    # (as ``cum > top_p`` alone does) leaves less than top_p probability mass.
    remove = torch.zeros_like(crossed)
    remove[..., 1:] = crossed[..., :-1]
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Put the filtered values back at their original positions.
    unsorted = torch.full_like(logits, float("-inf"))
    unsorted.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
    return unsorted
def apply_repetition_penalty_(logits_1d: torch.Tensor, token_ids: List[int], penalty: float) -> None:
    """Penalise, in place, the logits of tokens listed in ``token_ids``.

    Each distinct in-range token id is handled once: a negative logit is
    multiplied by ``penalty`` and a non-negative one is divided by it.

    Args:
        logits_1d: 1-D logits tensor, modified in place.
        token_ids: Token ids to penalise; out-of-range ids are ignored.
        penalty: Penalty factor. ``None``, ``1.0`` and values <= 0 make the
            call a no-op.
    """
    disabled = penalty is None or penalty == 1.0 or penalty <= 0
    if disabled or not token_ids:
        return
    vocab_size = logits_1d.numel()
    for tok in {int(t) for t in token_ids}:
        if not 0 <= tok < vocab_size:
            continue
        current = logits_1d[tok]
        if current < 0:
            logits_1d[tok] = current * penalty
        else:
            logits_1d[tok] = current / penalty
def apply_encoder_repetition_penalty_(logits_1d: torch.Tensor, prompt_ids: List[int], penalty: float) -> None:
    """Rescale, in place, the logits of tokens that occur in the prompt.

    This is the inverse of the repetition penalty: for each distinct in-range
    prompt token a negative logit is divided by ``penalty`` and a
    non-negative one is multiplied by it.

    Args:
        logits_1d: 1-D logits tensor, modified in place.
        prompt_ids: Prompt token ids; out-of-range ids are ignored.
        penalty: Scaling factor. ``None``, ``1.0`` and values <= 0 make the
            call a no-op.
    """
    disabled = penalty is None or penalty == 1.0 or penalty <= 0
    if disabled or not prompt_ids:
        return
    vocab_size = logits_1d.numel()
    for tok in {int(t) for t in prompt_ids}:
        if not 0 <= tok < vocab_size:
            continue
        current = logits_1d[tok]
        if current < 0:
            logits_1d[tok] = current / penalty
        else:
            logits_1d[tok] = current * penalty
def apply_presence_frequency_penalties_(
    logits_1d: torch.Tensor,
    token_ids: List[int],
    presence_penalty: float,
    frequency_penalty: float,
) -> None:
    """Subtract presence and frequency penalties from ``logits_1d`` in place.

    Args:
        logits_1d: 1-D logits tensor, modified in place.
        token_ids: Token ids whose occurrences are counted; out-of-range ids
            are ignored.
        presence_penalty: Flat amount subtracted once from every token that
            occurs at least once. ``None`` or 0 disables it.
        frequency_penalty: Amount subtracted once per occurrence of a token.
            ``None`` or 0 disables it.
    """
    # None and 0.0 are both falsy, so truthiness is the "enabled" test.
    use_presence = bool(presence_penalty)
    use_frequency = bool(frequency_penalty)
    if not (use_presence or use_frequency):
        return
    if not token_ids:
        return

    occurrences = Counter(int(t) for t in token_ids)
    vocab_size = logits_1d.numel()

    if use_presence:
        flat = float(presence_penalty)
        for tok in occurrences:
            if tok < 0 or tok >= vocab_size:
                continue
            logits_1d[tok] -= flat

    if use_frequency:
        per_hit = float(frequency_penalty)
        for tok, hits in occurrences.items():
            if tok < 0 or tok >= vocab_size:
                continue
            logits_1d[tok] -= per_hit * float(hits)
def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> Set[int]:
    """Return the tokens that would complete an n-gram already present in ``seq``.

    The last ``n - 1`` tokens of ``seq`` are taken as a prefix; every token
    that followed that same prefix earlier in ``seq`` is banned. For
    ``n == 1`` the prefix is empty, so every token already in ``seq`` is
    banned.

    Args:
        seq: Token ids seen so far (prompt plus generated tokens).
        n: N-gram size. Values <= 0 disable the check.

    Returns:
        The set of banned token ids; empty when the check is disabled, when
        ``seq`` is shorter than ``n - 1``, or when the prefix was never seen.
    """
    if n <= 0:
        return set()
    prefix_len = n - 1
    length = len(seq)
    if length < prefix_len:
        return set()

    # Map every (n-1)-token prefix to the set of tokens that followed it.
    followers: Dict[Tuple[int, ...], Set[int]] = {}
    for start in range(length - n + 1):
        prefix = tuple(int(x) for x in seq[start:start + prefix_len])
        followers.setdefault(prefix, set()).add(int(seq[start + prefix_len]))

    # Slice from an explicit start index: ``seq[-prefix_len:]`` returns the
    # whole sequence when prefix_len is 0 (n == 1), so the key would never
    # match the empty prefix and nothing would be banned.
    key = tuple(int(x) for x in seq[length - prefix_len:])
    return followers.get(key, set())
def mask_banned_tokens_(logits_1d: torch.Tensor, banned: Set[int]) -> None:
    """Set the logit of every banned token to ``-inf`` in place.

    Args:
        logits_1d: 1-D logits tensor, modified in place.
        banned: Token ids to mask; ids outside the tensor are ignored.
    """
    if not banned:
        return
    vocab_size = logits_1d.numel()
    neg_inf = float("-inf")
    for tok in banned:
        if tok < 0 or tok >= vocab_size:
            continue
        logits_1d[tok] = neg_inf
def sample_next_token(
    logits_1d: torch.Tensor,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
) -> int:
    """Pick the next token id from a 1-D logits tensor.

    The logits are divided by the temperature, passed through the top-k and
    then the top-p filter, and finally either sampled from or arg-maxed.

    Args:
        logits_1d: 1-D logits for the next position.
        do_sample: Sample from the filtered distribution when true; take the
            argmax of the filtered logits otherwise.
        temperature: Softmax temperature; non-positive values are replaced
            by 1.0.
        top_p: Threshold forwarded to ``top_p_filter``.
        top_k: Count forwarded to ``top_k_filter``.

    Returns:
        The chosen token id as a Python int.
    """
    scale = 1.0 if temperature <= 0 else float(temperature)
    scaled = logits_1d / scale
    # The filters are called with a leading batch dimension of size 1.
    after_top_k = top_k_filter(scaled.unsqueeze(0), int(top_k))[0]
    candidates = top_p_filter(after_top_k.unsqueeze(0), float(top_p))[0]
    if do_sample:
        probs = torch.softmax(candidates, dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())
    return int(torch.argmax(candidates, dim=-1).item())
# Module-level cache used by load_model: the most recently loaded model and
# the "repo@revision@device" key it was loaded under.
_MODEL: Optional[Any] = None
_MODEL_KEY: Optional[str] = None
def get_device() -> str:
    """Return ``"cuda"`` when a CUDA device is available, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def load_model(repo: str, revision: str, hf_token: str, trust_remote_code: bool) -> Any:
    """Load a causal LM, reusing the cached instance when possible.

    The cache holds a single model keyed by repo, revision and device; a call
    with the same key returns the cached model without reloading.

    Args:
        repo: Model repository id passed to ``from_pretrained``.
        revision: Revision to load; an empty value means "no revision".
        hf_token: Access token. When empty, the ``HF_TOKEN`` and then the
            ``HUGGINGFACEHUB_API_TOKEN`` environment variables are tried.
        trust_remote_code: Forwarded to ``from_pretrained`` as a bool.

    Returns:
        The model, moved to the current device and switched to eval mode.
    """
    global _MODEL, _MODEL_KEY

    device = get_device()
    cache_key = f"{repo}@{revision or ''}@{device}"
    if _MODEL is not None and _MODEL_KEY == cache_key:
        return _MODEL

    auth_token = hf_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    print(f"[load_model] repo={repo} revision={revision or 'None'} device={device} trust_remote_code={trust_remote_code}", flush=True)

    load_kwargs = {
        "revision": revision if revision else None,
        "token": auth_token,
        "trust_remote_code": bool(trust_remote_code),
        "torch_dtype": None,
        "low_cpu_mem_usage": True,
    }
    model = AutoModelForCausalLM.from_pretrained(repo, **load_kwargs)
    model.to(device)
    model.eval()

    config = getattr(model, "config", None)
    if config is not None:
        config.use_cache = True

    _MODEL, _MODEL_KEY = model, cache_key
    return _MODEL
@spaces.GPU
@torch.no_grad()
def run_infer(
    repo: str,
    revision: str,
    hf_token: str,
    trust_remote_code: bool,
    bos_id: int,
    eos_id: int,
    user_prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    seed: int,
    stop_on_eos: bool,
    skip_special_decode: bool,
    repetition_penalty: float,
    presence_penalty: float,
    frequency_penalty: float,
    encoder_repetition_penalty: float,
    no_repeat_ngram_size: int,
    show_logs: bool,
) -> Dict[str, Any]:
    """Generate a reply to ``user_prompt`` token by token and decode it.

    The prompt is turned into base-2 digits and laid out as
    ``[bos] + digits + [eos] + [bos]``. At every step the whole sequence so
    far is fed to the model, the last-position logits are adjusted by the
    encoder-repetition, repetition, presence/frequency and no-repeat-n-gram
    rules (in that order), and the next token is chosen by
    ``sample_next_token``.

    Args:
        repo, revision, hf_token, trust_remote_code: Forwarded to
            ``load_model``.
        bos_id, eos_id: Special token ids used to frame the prompt.
        user_prompt: Text of the user message.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample, temperature, top_p, top_k: Sampling settings.
        seed: RNG seed; ``None`` means 1234 and a negative value draws a
            random seed.
        stop_on_eos: Stop right after ``eos_id`` has been generated.
        skip_special_decode: Leave ``bos_id``/``eos_id`` out of the ids that
            are decoded to text.
        repetition_penalty, presence_penalty, frequency_penalty,
            encoder_repetition_penalty, no_repeat_ngram_size: Logit
            adjustments applied at every step.
        show_logs: Print inputs and outputs to stdout.

    Returns:
        A dict with the seed actually used, the prompt text, the input ids,
        the full and generated-only output ids, the ids used for decoding and
        the decoded text.
    """
    print("[DEBUG] APP VERSION = 2026-04-15-BINARY-RAW-PROMPT", flush=True)
    print(f"[DEBUG] repo={repo} revision={revision}", flush=True)
    print(f"[DEBUG] torch={torch.__version__}", flush=True)
    print(f"[DEBUG] cuda={torch.version.cuda}", flush=True)
    print(f"[DEBUG] device={get_device()}", flush=True)

    seed = 1234 if seed is None else int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    set_seed(seed)

    model = load_model(repo, revision, hf_token, trust_remote_code)
    device = get_device()
    bos_id, eos_id = int(bos_id), int(eos_id)

    prompt_text = build_prompt_text(user_prompt)
    framed_prompt = wrap_base2_sequence_2(text_to_base2_digits(prompt_text), bos_id, eos_id)
    # A trailing BOS after the framed prompt opens the reply segment.
    input_ids: List[int] = framed_prompt + [bos_id]
    prompt_ids_for_penalty: List[int] = list(input_ids)
    out_ids = list(input_ids)
    gen_ids: List[int] = []

    if show_logs:
        print(f"\n[Seed] {seed}", flush=True)
        print("=== INPUT TEXT ===", flush=True)
        print(prompt_text, flush=True)
        print("=== INPUT IDS ===", flush=True)
        print(input_ids, flush=True)
        print(f"[info] input_len={len(input_ids)} base=2 bos_id={bos_id} eos_id={eos_id} device={device}", flush=True)

    special_ids = (bos_id, eos_id)
    for _ in range(int(max_new_tokens)):
        # The full sequence is re-encoded at every step.
        batch = torch.tensor([out_ids], dtype=torch.long, device=device)
        next_logits = model(input_ids=batch, use_cache=True).logits[0, -1, :].clone()

        apply_encoder_repetition_penalty_(next_logits, prompt_ids_for_penalty, float(encoder_repetition_penalty))
        apply_repetition_penalty_(next_logits, out_ids, float(repetition_penalty))
        apply_presence_frequency_penalties_(next_logits, out_ids, float(presence_penalty), float(frequency_penalty))
        if int(no_repeat_ngram_size) > 0:
            mask_banned_tokens_(
                next_logits,
                get_banned_tokens_no_repeat_ngram(out_ids, int(no_repeat_ngram_size)),
            )

        next_id = sample_next_token(
            next_logits,
            do_sample=bool(do_sample),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

        out_ids.append(int(next_id))
        if next_id == eos_id and bool(stop_on_eos):
            # The terminating EOS is recorded in out_ids but never decoded.
            break
        if bool(skip_special_decode) and next_id in special_ids:
            continue
        gen_ids.append(int(next_id))

    decoded = decode_base2_digits_strict(gen_ids, encoding="utf-8", errors="replace")
    generated_only = out_ids[len(input_ids):]

    if show_logs:
        print("\n=== OUTPUT IDS (FULL) ===", flush=True)
        print(out_ids, flush=True)
        print("\n=== OUTPUT IDS (GENERATED ONLY) ===", flush=True)
        print(generated_only, flush=True)
        print("\n=== DECODE (GENERATED ONLY) ===", flush=True)
        print(decoded, flush=True)

    return {
        "seed": seed,
        "prompt_text": prompt_text,
        "input_ids": input_ids,
        "out_ids_full": out_ids,
        "out_ids_generated": generated_only,
        "gen_ids_for_decode": gen_ids,
        "decoded": decoded,
    }
def ensure_history_msgs(history) -> List[Dict[str, str]]:
    """Return ``history`` if it looks like a list of chat messages, else ``[]``.

    Only the first element is inspected: it must be a dict containing both a
    ``"role"`` and a ``"content"`` key. A valid history is returned as the
    same list object, not a copy.
    """
    if not isinstance(history, list) or not history:
        return []
    head = history[0]
    if isinstance(head, dict) and "role" in head and "content" in head:
        return history
    return []
def on_send(
    user_msg: str,
    history,
    repo,
    revision,
    hf_token,
    trust_remote_code,
    bos_id,
    eos_id,
    max_new_tokens,
    do_sample,
    temperature,
    top_p,
    top_k,
    seed,
    stop_on_eos,
    skip_special_decode,
    repetition_penalty,
    presence_penalty,
    frequency_penalty,
    encoder_repetition_penalty,
    no_repeat_ngram_size,
    show_debug,
    show_logs,
):
    """Handle a submitted chat message: run inference and update the history.

    The stripped user message is appended to the history, ``run_infer`` is
    called with the UI values coerced to their expected types, and the
    decoded reply is appended as the assistant message. A blank message
    leaves the history unchanged.

    Args:
        user_msg: Raw text from the input box.
        history: Stored chat history (list of role/content dicts) or anything
            else, in which case a fresh history is started.
        show_debug: When true, also build the debug text for the UI.
        (remaining parameters): Raw UI values forwarded to ``run_infer``.

    Returns:
        A 3-tuple ``(chat_messages, stored_history, debug_text)``; the first
        two are the same list object.
    """
    history_msgs = ensure_history_msgs(history)
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return history_msgs, history_msgs, ""

    history_msgs.append({"role": "user", "content": user_msg})
    res = run_infer(
        repo=str(repo),
        revision=str(revision or ""),
        hf_token=str(hf_token or ""),
        trust_remote_code=bool(trust_remote_code),
        bos_id=int(bos_id),
        eos_id=int(eos_id),
        user_prompt=str(user_msg),
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        # An emptied seed field can arrive as None; int(None) would raise
        # TypeError. Forward None so run_infer applies its own default seed.
        seed=None if seed is None else int(seed),
        stop_on_eos=bool(stop_on_eos),
        skip_special_decode=bool(skip_special_decode),
        repetition_penalty=float(repetition_penalty),
        presence_penalty=float(presence_penalty),
        frequency_penalty=float(frequency_penalty),
        encoder_repetition_penalty=float(encoder_repetition_penalty),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        show_logs=bool(show_logs),
    )
    assistant_text = res["decoded"]
    history_msgs.append({"role": "assistant", "content": assistant_text})

    debug_txt = ""
    if bool(show_debug):
        debug_txt = (
            f"[Seed] {res['seed']}\n"
            f"=== INPUT TEXT ===\n{res['prompt_text']}\n\n"
            f"=== INPUT IDS ===\n{res['input_ids']}\n\n"
            f"=== OUTPUT IDS (FULL) ===\n{res['out_ids_full']}\n\n"
            f"=== OUTPUT IDS (GENERATED ONLY) ===\n{res['out_ids_generated']}\n\n"
            f"=== DECODE (GENERATED ONLY) ===\n{assistant_text}\n"
        )
    return history_msgs, history_msgs, debug_txt
def on_clear():
    """Return empty values for the chat display, the stored history and the debug box."""
    empty_chat, empty_history = [], []
    return empty_chat, empty_history, ""
# ---------------------------------------------------------------------------
# Gradio UI: chat on the left, model and generation settings on the right.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Binary-LLM-POC")

    with gr.Row():
        # Left column: conversation, message box, actions and debug output.
        with gr.Column(scale=2):
            chat = gr.Chatbot(label="Chat")
            state = gr.State([])
            user_in = gr.Textbox(label="User", placeholder="Tape ton message…", lines=3)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
            debug_out = gr.Textbox(label="Debug (optional)", lines=18)

        # Right column: every knob that is forwarded to on_send.
        with gr.Column(scale=1):
            gr.Markdown("### Model / HF")
            repo = gr.Textbox(label="Repo", value="PhysiQuanty/Binary-LLM-POC")
            revision = gr.Textbox(label="Revision (optional)", value="")
            hf_token = gr.Textbox(label="HF_TOKEN (optional)", value="", type="password")
            trust_remote_code = gr.Checkbox(label="trust_remote_code", value=True)

            gr.Markdown("### Specials / base")
            bos_id = gr.Number(label="bos_id", value=2, precision=0)
            eos_id = gr.Number(label="eos_id", value=3, precision=0)

            gr.Markdown("### Sampling")
            max_new_tokens = gr.Slider(label="max_new_tokens", minimum=1, maximum=4096, value=2048, step=1)
            do_sample = gr.Checkbox(label="do_sample", value=True)
            temperature = gr.Slider(label="temperature", minimum=0.01, maximum=2.0, value=0.7, step=0.01)
            top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=1.0, step=0.001)
            top_k = gr.Slider(label="top_k", minimum=0, maximum=50, value=50, step=1)
            seed = gr.Number(label="seed (-1=random)", value=-1, precision=0)

            gr.Markdown("### Stops / decode")
            stop_on_eos = gr.Checkbox(label="stop_on_eos", value=True)
            skip_special_decode = gr.Checkbox(label="skip_special_decode", value=True)

            gr.Markdown("### Penalties (5)")
            repetition_penalty = gr.Slider(label="repetition_penalty", minimum=0.0, maximum=3.0, value=1.0, step=0.01)
            encoder_repetition_penalty = gr.Slider(label="encoder_repetition_penalty", minimum=0.0, maximum=3.0, value=1.0, step=0.01)
            presence_penalty = gr.Slider(label="presence_penalty", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
            frequency_penalty = gr.Slider(label="frequency_penalty", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
            no_repeat_ngram_size = gr.Slider(label="no_repeat_ngram_size", minimum=0, maximum=20, value=0, step=1)

            gr.Markdown("### Debug / logs")
            show_debug = gr.Checkbox(label="Show debug block (UI)", value=True)
            show_logs = gr.Checkbox(label="Print logs to container logs", value=True)

    # Same inputs/outputs for the Send button and for submitting the textbox;
    # the order must match the positional parameters of on_send.
    send_inputs = [
        user_in, state,
        repo, revision, hf_token, trust_remote_code,
        bos_id, eos_id,
        max_new_tokens, do_sample, temperature, top_p, top_k,
        seed, stop_on_eos, skip_special_decode,
        repetition_penalty, presence_penalty, frequency_penalty, encoder_repetition_penalty, no_repeat_ngram_size,
        show_debug, show_logs,
    ]
    send_outputs = [chat, state, debug_out]

    send_btn.click(fn=on_send, inputs=send_inputs, outputs=send_outputs)
    user_in.submit(fn=on_send, inputs=send_inputs, outputs=send_outputs)
    clear_btn.click(fn=on_clear, inputs=[], outputs=[chat, state, debug_out])

    # Separate listeners that empty the message box on the same two triggers.
    send_btn.click(lambda: "", inputs=[], outputs=[user_in])
    user_in.submit(lambda: "", inputs=[], outputs=[user_in])

demo.queue().launch(ssr_mode=False)