#!/usr/bin/env python3
# llmTalk_ids_v8_hf.py
# ============================================================
# ID-ONLY INFERENCE (vocab=8):
# 0/1 bits + 6 specials: BOS EOS BOI EOI BOR EOR
#
# Two prompt modes:
# - --prompt_ids : digit string (e.g. "240000001540000015") (digits only, 0..7) (may be "")
# - --prompt_int : "int,int" string -> builds: BOS t0 t1 BOI int1(10b) EOI BOI int2(10b) EOI
#
# Option:
# - --print_int : extracts the first BOR ... EOR block (variable bit count) from the full sequence
#                 and prints its decimal value (binary -> int).
#                 (min_bits=10 by default to match the 10-bit inputs, but the answer may be longer)
# ============================================================
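#
# Example invocations (illustrative sketch; the script filename and repo id
# below are assumptions, adapt them to your own checkout / HF repo):
#   python inference.py --repo PhysiQuanty/Binary-Addition-LLM-POC --prompt_int "12,900" --print_int --stop_on_eos
#   python inference.py --repo PhysiQuanty/Binary-Addition-LLM-POC --prompt_ids "240000001540000015" --print_final_readable
# ============================================================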
import sys
import argparse
import random
from collections import Counter
from typing import List, Dict, Tuple, Any, Optional
import torch
from transformers import AutoModelForCausalLM
# ----------------------------
# Special tokens (vocab=8)
# ----------------------------
TOK_BOS = 2
TOK_EOS = 3
TOK_BOI = 4
TOK_EOI = 5
TOK_BOR = 6
TOK_EOR = 7
TOK_NAMES = {
0: "0",
1: "1",
TOK_BOS: "BOS",
TOK_EOS: "EOS",
TOK_BOI: "BOI",
TOK_EOI: "EOI",
TOK_BOR: "BOR",
TOK_EOR: "EOR",
}
# ------------------------------------------------------------
# Task header bits for --prompt_int (t0, t1)
# ------------------------------------------------------------
# The requested prompt format is "BOS t0 t1 ...", with t0/t1 left unspecified.
# A neutral default of 0,0 is used here (change it if needed).
PROMPT_INT_T0 = 0
PROMPT_INT_T1 = 0
# ----------------------------
# Logits modifiers
# ----------------------------
def apply_repetition_penalty_(logits: torch.Tensor, token_ids: List[int], penalty: float) -> None:
if penalty is None or penalty == 1.0 or penalty <= 0:
return
for t in set(token_ids):
val = logits[0, t]
logits[0, t] = val * penalty if val < 0 else val / penalty
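# Illustration of the sign-dependent scaling above (HF-style repetition penalty):
# with penalty=1.2, a seen token's logit of 2.0 becomes 2.0 / 1.2 ≈ 1.67 and a
# logit of -2.0 becomes -2.0 * 1.2 = -2.4, so tokens already present in the
# sequence are made less likely in both cases.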
def apply_encoder_repetition_penalty_(logits: torch.Tensor, prompt_token_ids: List[int], penalty: float) -> None:
if penalty is None or penalty == 1.0 or penalty <= 0:
return
for t in set(prompt_token_ids):
val = logits[0, t]
logits[0, t] = val / penalty if val < 0 else val * penalty
def apply_presence_frequency_penalties_(
logits: torch.Tensor,
token_ids: List[int],
presence_penalty: float,
frequency_penalty: float,
) -> None:
counts = Counter(token_ids)
if presence_penalty:
for t in counts:
logits[0, t] -= presence_penalty
if frequency_penalty:
for t, c in counts.items():
logits[0, t] -= frequency_penalty * c
def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> set:
if n <= 0 or len(seq) < n - 1:
return set()
prefix_len = n - 1
ngrams: Dict[Tuple[int, ...], set] = {}
for i in range(len(seq) - n + 1):
prefix = tuple(seq[i:i + prefix_len])
nxt = seq[i + prefix_len]
ngrams.setdefault(prefix, set()).add(nxt)
return ngrams.get(tuple(seq[-prefix_len:]), set())
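# Illustration: with seq=[0, 1, 0, 1, 0] and n=2 the recorded bigram map is
# {(0,): {1}, (1,): {0}}; the current prefix is (0,), so {1} is returned and
# the bigram (0, 1) cannot be repeated.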
def mask_banned_tokens_(logits: torch.Tensor, banned: set) -> None:
if banned:
logits[0, list(banned)] = float("-inf")
# ----------------------------
# Helpers: prompt parsing + pretty print
# ----------------------------
def parse_prompt_ids_str(s: str, vocab_size: int = 8) -> List[int]:
s = "" if s is None else str(s)
s = s.strip()
if s == "":
return []
if not s.isdigit():
raise ValueError("prompt_ids doit contenir uniquement des chiffres (0..7), sans espaces.")
ids: List[int] = []
for ch in s:
t = ord(ch) - ord("0")
if t < 0 or t >= vocab_size:
raise ValueError(f"token id hors vocab: {t} (vocab_size={vocab_size})")
ids.append(t)
return ids
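# Illustration: parse_prompt_ids_str("2401") -> [2, 4, 0, 1]; a character outside
# "0".."7" (e.g. "9") or an embedded space raises ValueError.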
def format_ids_readable(ids: List[int]) -> str:
out: List[str] = []
for t in ids:
out.append(TOK_NAMES.get(int(t), str(int(t))))
return " ".join(out)
def format_ids_compact(ids: List[int]) -> str:
s: List[str] = []
for t in ids:
ti = int(t)
if ti in (0, 1):
if s and (s[-1] and s[-1][-1] in ("0", "1")):
s[-1] = s[-1] + str(ti)
else:
s.append(str(ti))
else:
s.append(TOK_NAMES.get(ti, str(ti)))
return " ".join(s)
# ----------------------------
# --prompt_int builder
# ----------------------------
def int_to_10bits_tokens(x: int) -> List[int]:
if x < 0 or x > 1023:
raise ValueError(f"int hors range pour 10 bits: {x} (attendu 0..1023)")
b = format(int(x), "010b") # MSB -> LSB
return [0 if ch == "0" else 1 for ch in b]
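# Illustration: int_to_10bits_tokens(12) -> [0, 0, 0, 0, 0, 0, 1, 1, 0, 0]
# (MSB first); values outside 0..1023 raise ValueError.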
def parse_prompt_int_str(s: str) -> Tuple[int, int]:
s = "" if s is None else str(s)
s = s.strip()
if s == "":
raise ValueError("--prompt_int vide. Attendu: \"int,int\"")
parts = s.split(",")
if len(parts) != 2:
raise ValueError(f"--prompt_int invalide: {s!r}. Attendu: \"int,int\"")
try:
a = int(parts[0].strip())
b = int(parts[1].strip())
except Exception:
raise ValueError(f"--prompt_int invalide: {s!r}. Les deux valeurs doivent être des int.")
return a, b
def build_prompt_from_ints(int1: int, int2: int) -> List[int]:
seq: List[int] = []
seq.append(TOK_BOS)
seq.append(int(PROMPT_INT_T0))
seq.append(int(PROMPT_INT_T1))
seq.append(TOK_BOI)
seq.extend(int_to_10bits_tokens(int1))
seq.append(TOK_EOI)
seq.append(TOK_BOI)
seq.extend(int_to_10bits_tokens(int2))
seq.append(TOK_EOI)
return seq
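# Illustration: build_prompt_from_ints(12, 900) with the default t0,t1 = 0,0 yields
# BOS 0 0 BOI 0000001100 EOI BOI 1110000100 EOI, i.e.
# [2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 5, 4, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 5].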
# ----------------------------
# --print_int extractor (BOR ... EOR, bits variables)
# ----------------------------
def extract_first_bor_eor_bits(ids: List[int], min_bits: int = 1) -> Optional[Tuple[List[int], int, int]]:
try:
i = ids.index(TOK_BOR)
except ValueError:
return None
bits: List[int] = []
j = i + 1
while j < len(ids):
t = int(ids[j])
if t == TOK_EOR:
break
if t in (0, 1):
bits.append(t)
j += 1
if len(bits) < int(min_bits):
return None
val = 0
for b in bits:
val = (val << 1) | int(b)
return bits, val, i
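# Illustration: extract_first_bor_eor_bits([6, 0, 0, 1, 1, 7], min_bits=1)
# -> ([0, 0, 1, 1], 3, 0): the bits between BOR(6) and EOR(7), their integer
# value read MSB-first, and the index of BOR in the sequence; non-bit tokens
# inside the block are skipped.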
# ----------------------------
# Main
# ----------------------------
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--repo", type=str, required=True, help='HF repo id ou path local (ex: "PhysiQuanty/xxx")')
parser.add_argument("--revision", type=str, default=None, help="HF revision/branch/tag/commit (optionnel)")
g = parser.add_mutually_exclusive_group(required=False)
g.add_argument("--prompt_ids", type=str, default=None, help='Ex: "240000001540000015" (digits only 0..7) or ""')
g.add_argument("--prompt_int", type=str, default=None, help='Ex: "12,900" -> BOS t0 t1 BOI 10b EOI BOI 10b EOI')
parser.add_argument("--print_int", action="store_true", help="Affiche le 1er bloc BOR..EOR (bits) en int")
parser.add_argument("--max_new_tokens", type=int, default=40)
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--top_k", type=int, default=50)
parser.add_argument("--repetition_penalty", type=float, default=1.0)
parser.add_argument("--presence_penalty", type=float, default=0.0)
parser.add_argument("--frequency_penalty", type=float, default=0.0)
parser.add_argument("--encoder_repetition_penalty", type=float, default=1.0)
parser.add_argument("--no_repeat_ngram_size", type=int, default=0)
parser.add_argument("--seed", type=int, default=-1)
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
parser.add_argument("--stream_ids", action="store_true", help="Stream les IDS générés au fil de l'eau")
parser.add_argument("--print_prompt_readable", action="store_true", help="Affiche prompt en tokens lisibles")
parser.add_argument("--print_final_readable", action="store_true", help="Affiche sortie finale en tokens lisibles")
parser.add_argument("--stop_on_eos", action="store_true", help="Stop dès que EOS(3) est généré")
args = parser.parse_args()
seed = args.seed if args.seed >= 0 else random.randint(0, 2**31 - 1)
print(f"[Seed] {seed}", flush=True)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
device = torch.device("cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
print(f"[Device] {device}", flush=True)
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
args.repo,
revision=args.revision,
trust_remote_code=True,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
)
model.to(device)
model.eval()
vocab_size_cfg = int(getattr(model.config, "vocab_size", -1))
print(f"[Model] loaded from {args.repo} | vocab_size={vocab_size_cfg}", flush=True)
if vocab_size_cfg != 8:
print(f"[Warn] vocab_size={vocab_size_cfg} (attendu 8).", flush=True)
# ---- build prompt ids from either --prompt_int or --prompt_ids (or default "")
if args.prompt_int is not None:
int1, int2 = parse_prompt_int_str(args.prompt_int)
prompt_ids = build_prompt_from_ints(int1, int2)
prompt_origin = f'prompt_int="{args.prompt_int}" (t0,t1={PROMPT_INT_T0},{PROMPT_INT_T1})'
else:
s = "" if args.prompt_ids is None else args.prompt_ids
prompt_ids = parse_prompt_ids_str(s, vocab_size=8)
prompt_origin = 'prompt_ids' if args.prompt_ids is not None else 'prompt_ids="" (default)'
print(f"[Prompt Origin] {prompt_origin}", flush=True)
if args.print_prompt_readable:
print(f"[Prompt IDs] {prompt_ids}", flush=True)
print(f"[Prompt readable] {format_ids_readable(prompt_ids)}", flush=True)
print(f"[Prompt compact] {format_ids_compact(prompt_ids)}", flush=True)
else:
if len(prompt_ids) == 0:
print("[Prompt IDs] len=0 (prompt nul)", flush=True)
else:
print(f"[Prompt IDs] len={len(prompt_ids)} first32={prompt_ids[:32]}", flush=True)
seeded_with_bos = False
if len(prompt_ids) == 0:
tokens = torch.tensor([TOK_BOS], device=device, dtype=torch.long).unsqueeze(0)
seeded_with_bos = True
else:
tokens = torch.tensor(prompt_ids, device=device, dtype=torch.long).unsqueeze(0)
generated_raw: List[int] = []
if args.stream_ids:
sys.stdout.write("[Stream IDS] ")
sys.stdout.flush()
with torch.no_grad():
for _ in range(int(args.max_new_tokens)):
out = model(input_ids=tokens)
logits = out.logits[:, -1, :] # (1, vocab)
logits_work = logits.clone()
full_seq = tokens[0].tolist()
apply_encoder_repetition_penalty_(logits_work, prompt_ids, float(args.encoder_repetition_penalty))
apply_repetition_penalty_(logits_work, full_seq, float(args.repetition_penalty))
apply_presence_frequency_penalties_(
logits_work,
full_seq,
float(args.presence_penalty),
float(args.frequency_penalty),
)
if int(args.no_repeat_ngram_size) > 0:
banned = get_banned_tokens_no_repeat_ngram(full_seq, int(args.no_repeat_ngram_size))
mask_banned_tokens_(logits_work, banned)
logits_work /= max(float(args.temperature), 1e-6)
if 0 < int(args.top_k) < logits_work.size(-1):
v, _ = torch.topk(logits_work, int(args.top_k))
logits_work[logits_work < v[:, [-1]]] = float("-inf")
probs = torch.softmax(logits_work, dim=-1)
next_token = torch.multinomial(probs, 1) # (1,1)
tok_id = int(next_token.item())
generated_raw.append(tok_id)
if args.stream_ids:
sys.stdout.write(str(tok_id))
sys.stdout.flush()
tokens = torch.cat([tokens, next_token], dim=1)
if args.stop_on_eos and tok_id == TOK_EOS:
break
if args.stream_ids:
sys.stdout.write("\n")
sys.stdout.flush()
if seeded_with_bos:
print("\n[Prompt] prompt nul -> seed interne BOS(2) utilisé uniquement pour init logits", flush=True)
print("\n[Generated RAW IDS]", flush=True)
print(generated_raw, flush=True)
print("\n[Generated RAW IDS (as digits)]", flush=True)
print("".join(str(x) for x in generated_raw), flush=True)
if args.print_final_readable or args.print_int:
full = prompt_ids + generated_raw
if args.print_final_readable:
print("\n[Full sequence readable]", flush=True)
print(format_ids_readable(full), flush=True)
print("\n[Full sequence compact]", flush=True)
print(format_ids_compact(full), flush=True)
if args.print_int:
got = extract_first_bor_eor_bits(full, min_bits=10)
if got is None:
print("\n[PrintInt] Aucun bloc BOR..EOR valide trouvé.", flush=True)
else:
bits, val, pos = got
bits_str = "".join(str(b) for b in bits)
print("\n[PrintInt] First BOR..EOR", flush=True)
print(f"[PrintInt] pos={pos} nbits={len(bits)} bits={bits_str} int={val}", flush=True)
if __name__ == "__main__":
main()