#!/usr/bin/env python3 # llmTalk_ids_v8_hf.py # ============================================================ # INFERENCE EN IDS UNIQUEMENT (vocab=8): # 0/1 bits + 6 specials: BOS EOS BOI EOI BOR EOR # # Deux modes de prompt: # - --prompt_ids : string de chiffres (ex: "240000001540000015") (digits only, 0..7) (peut être "") # - --prompt_int : string "int,int" -> génère: BOS t0 t1 BOI int1(10b) EOI BOI int2(10b) EOI # # Option: # - --print_int : extrait le premier bloc BOR ... EOR (bits variables) dans la séquence complète # et affiche sa valeur décimale (binaire -> int). # (min_bits=10 par défaut pour coller à tes entrées 10 bits, mais la réponse peut dépasser) # ============================================================ import sys import argparse import random from collections import Counter from typing import List, Dict, Tuple, Any, Optional import torch from transformers import AutoModelForCausalLM # ---------------------------- # Special tokens (vocab=8) # ---------------------------- TOK_BOS = 2 TOK_EOS = 3 TOK_BOI = 4 TOK_EOI = 5 TOK_BOR = 6 TOK_EOR = 7 TOK_NAMES = { 0: "0", 1: "1", TOK_BOS: "BOS", TOK_EOS: "EOS", TOK_BOI: "BOI", TOK_EOI: "EOI", TOK_BOR: "BOR", TOK_EOR: "EOR", } # ------------------------------------------------------------ # Task header bits for --prompt_int (t0, t1) # ------------------------------------------------------------ # Tu as demandé "BOS t0 t1 ...", sans préciser t0/t1. # Ici je mets un défaut neutre: 0,0 (modifiable si tu veux). PROMPT_INT_T0 = 0 PROMPT_INT_T1 = 0 # ---------------------------- # Logits modifiers # ---------------------------- def apply_repetition_penalty_(logits: torch.Tensor, token_ids: List[int], penalty: float) -> None: if penalty is None or penalty == 1.0 or penalty <= 0: return for t in set(token_ids): val = logits[0, t] logits[0, t] = val * penalty if val < 0 else val / penalty def apply_encoder_repetition_penalty_(logits: torch.Tensor, prompt_token_ids: List[int], penalty: float) -> None: if penalty is None or penalty == 1.0 or penalty <= 0: return for t in set(prompt_token_ids): val = logits[0, t] logits[0, t] = val / penalty if val < 0 else val * penalty def apply_presence_frequency_penalties_( logits: torch.Tensor, token_ids: List[int], presence_penalty: float, frequency_penalty: float, ) -> None: counts = Counter(token_ids) if presence_penalty: for t in counts: logits[0, t] -= presence_penalty if frequency_penalty: for t, c in counts.items(): logits[0, t] -= frequency_penalty * c def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> set: if n <= 0 or len(seq) < n - 1: return set() prefix_len = n - 1 ngrams: Dict[Tuple[int, ...], set] = {} for i in range(len(seq) - n + 1): prefix = tuple(seq[i:i + prefix_len]) nxt = seq[i + prefix_len] ngrams.setdefault(prefix, set()).add(nxt) return ngrams.get(tuple(seq[-prefix_len:]), set()) def mask_banned_tokens_(logits: torch.Tensor, banned: set) -> None: if banned: logits[0, list(banned)] = float("-inf") # ---------------------------- # Helpers: prompt parsing + pretty print # ---------------------------- def parse_prompt_ids_str(s: str, vocab_size: int = 8) -> List[int]: s = "" if s is None else str(s) s = s.strip() if s == "": return [] if not s.isdigit(): raise ValueError("prompt_ids doit contenir uniquement des chiffres (0..7), sans espaces.") ids: List[int] = [] for ch in s: t = ord(ch) - ord("0") if t < 0 or t >= vocab_size: raise ValueError(f"token id hors vocab: {t} (vocab_size={vocab_size})") ids.append(t) return ids def format_ids_readable(ids: List[int]) -> str: out: List[str] = [] for t in ids: out.append(TOK_NAMES.get(int(t), str(int(t)))) return " ".join(out) def format_ids_compact(ids: List[int]) -> str: s: List[str] = [] for t in ids: ti = int(t) if ti in (0, 1): if s and (s[-1] and s[-1][-1] in ("0", "1")): s[-1] = s[-1] + str(ti) else: s.append(str(ti)) else: s.append(TOK_NAMES.get(ti, str(ti))) return " ".join(s) # ---------------------------- # --prompt_int builder # ---------------------------- def int_to_10bits_tokens(x: int) -> List[int]: if x < 0 or x > 1023: raise ValueError(f"int hors range pour 10 bits: {x} (attendu 0..1023)") b = format(int(x), "010b") # MSB -> LSB return [0 if ch == "0" else 1 for ch in b] def parse_prompt_int_str(s: str) -> Tuple[int, int]: s = "" if s is None else str(s) s = s.strip() if s == "": raise ValueError("--prompt_int vide. Attendu: \"int,int\"") parts = s.split(",") if len(parts) != 2: raise ValueError(f"--prompt_int invalide: {s!r}. Attendu: \"int,int\"") try: a = int(parts[0].strip()) b = int(parts[1].strip()) except Exception: raise ValueError(f"--prompt_int invalide: {s!r}. Les deux valeurs doivent être des int.") return a, b def build_prompt_from_ints(int1: int, int2: int) -> List[int]: seq: List[int] = [] seq.append(TOK_BOS) seq.append(int(PROMPT_INT_T0)) seq.append(int(PROMPT_INT_T1)) seq.append(TOK_BOI) seq.extend(int_to_10bits_tokens(int1)) seq.append(TOK_EOI) seq.append(TOK_BOI) seq.extend(int_to_10bits_tokens(int2)) seq.append(TOK_EOI) return seq # ---------------------------- # --print_int extractor (BOR ... EOR, bits variables) # ---------------------------- def extract_first_bor_eor_bits(ids: List[int], min_bits: int = 1) -> Optional[Tuple[List[int], int, int]]: try: i = ids.index(TOK_BOR) except ValueError: return None bits: List[int] = [] j = i + 1 while j < len(ids): t = int(ids[j]) if t == TOK_EOR: break if t in (0, 1): bits.append(t) j += 1 if len(bits) < int(min_bits): return None val = 0 for b in bits: val = (val << 1) | int(b) return bits, val, i # ---------------------------- # Main # ---------------------------- def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--repo", type=str, required=True, help='HF repo id ou path local (ex: "PhysiQuanty/xxx")') parser.add_argument("--revision", type=str, default=None, help="HF revision/branch/tag/commit (optionnel)") g = parser.add_mutually_exclusive_group(required=False) g.add_argument("--prompt_ids", type=str, default=None, help='Ex: "240000001540000015" (digits only 0..7) or ""') g.add_argument("--prompt_int", type=str, default=None, help='Ex: "12,900" -> BOS t0 t1 BOI 10b EOI BOI 10b EOI') parser.add_argument("--print_int", action="store_true", help="Affiche le 1er bloc BOR..EOR (bits) en int") parser.add_argument("--max_new_tokens", type=int, default=40) parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--top_k", type=int, default=50) parser.add_argument("--repetition_penalty", type=float, default=1.0) parser.add_argument("--presence_penalty", type=float, default=0.0) parser.add_argument("--frequency_penalty", type=float, default=0.0) parser.add_argument("--encoder_repetition_penalty", type=float, default=1.0) parser.add_argument("--no_repeat_ngram_size", type=int, default=0) parser.add_argument("--seed", type=int, default=-1) parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"]) parser.add_argument("--stream_ids", action="store_true", help="Stream les IDS générés au fil de l'eau") parser.add_argument("--print_prompt_readable", action="store_true", help="Affiche prompt en tokens lisibles") parser.add_argument("--print_final_readable", action="store_true", help="Affiche sortie finale en tokens lisibles") parser.add_argument("--stop_on_eos", action="store_true", help="Stop dès que EOS(3) est généré") args = parser.parse_args() seed = args.seed if args.seed >= 0 else random.randint(0, 2**31 - 1) print(f"[Seed] {seed}", flush=True) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) device = torch.device("cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu") print(f"[Device] {device}", flush=True) torch_dtype = torch.float16 if device.type == "cuda" else torch.float32 model = AutoModelForCausalLM.from_pretrained( args.repo, revision=args.revision, trust_remote_code=True, torch_dtype=torch_dtype, low_cpu_mem_usage=True, ) model.to(device) model.eval() vocab_size_cfg = int(getattr(model.config, "vocab_size", -1)) print(f"[Model] loaded from {args.repo} | vocab_size={vocab_size_cfg}", flush=True) if vocab_size_cfg != 8: print(f"[Warn] vocab_size={vocab_size_cfg} (attendu 8).", flush=True) # ---- build prompt ids from either --prompt_int or --prompt_ids (or default "") if args.prompt_int is not None: int1, int2 = parse_prompt_int_str(args.prompt_int) prompt_ids = build_prompt_from_ints(int1, int2) prompt_origin = f'prompt_int="{args.prompt_int}" (t0,t1={PROMPT_INT_T0},{PROMPT_INT_T1})' else: s = "" if args.prompt_ids is None else args.prompt_ids prompt_ids = parse_prompt_ids_str(s, vocab_size=8) prompt_origin = 'prompt_ids' if args.prompt_ids is not None else 'prompt_ids="" (default)' print(f"[Prompt Origin] {prompt_origin}", flush=True) if args.print_prompt_readable: print(f"[Prompt IDs] {prompt_ids}", flush=True) print(f"[Prompt readable] {format_ids_readable(prompt_ids)}", flush=True) print(f"[Prompt compact] {format_ids_compact(prompt_ids)}", flush=True) else: if len(prompt_ids) == 0: print("[Prompt IDs] len=0 (prompt nul)", flush=True) else: print(f"[Prompt IDs] len={len(prompt_ids)} first32={prompt_ids[:32]}", flush=True) seeded_with_bos = False if len(prompt_ids) == 0: tokens = torch.tensor([TOK_BOS], device=device, dtype=torch.long).unsqueeze(0) seeded_with_bos = True else: tokens = torch.tensor(prompt_ids, device=device, dtype=torch.long).unsqueeze(0) generated_raw: List[int] = [] if args.stream_ids: sys.stdout.write("[Stream IDS] ") sys.stdout.flush() with torch.no_grad(): for _ in range(int(args.max_new_tokens)): out = model(input_ids=tokens) logits = out.logits[:, -1, :] # (1, vocab) logits_work = logits.clone() full_seq = tokens[0].tolist() apply_encoder_repetition_penalty_(logits_work, prompt_ids, float(args.encoder_repetition_penalty)) apply_repetition_penalty_(logits_work, full_seq, float(args.repetition_penalty)) apply_presence_frequency_penalties_( logits_work, full_seq, float(args.presence_penalty), float(args.frequency_penalty), ) if int(args.no_repeat_ngram_size) > 0: banned = get_banned_tokens_no_repeat_ngram(full_seq, int(args.no_repeat_ngram_size)) mask_banned_tokens_(logits_work, banned) logits_work /= max(float(args.temperature), 1e-6) if 0 < int(args.top_k) < logits_work.size(-1): v, _ = torch.topk(logits_work, int(args.top_k)) logits_work[logits_work < v[:, [-1]]] = float("-inf") probs = torch.softmax(logits_work, dim=-1) next_token = torch.multinomial(probs, 1) # (1,1) tok_id = int(next_token.item()) generated_raw.append(tok_id) if args.stream_ids: sys.stdout.write(str(tok_id)) sys.stdout.flush() tokens = torch.cat([tokens, next_token], dim=1) if args.stop_on_eos and tok_id == TOK_EOS: break if args.stream_ids: sys.stdout.write("\n") sys.stdout.flush() if seeded_with_bos: print("\n[Prompt] prompt nul -> seed interne BOS(2) utilisé uniquement pour init logits", flush=True) print("\n[Generated RAW IDS]", flush=True) print(generated_raw, flush=True) print("\n[Generated RAW IDS (as digits)]", flush=True) print("".join(str(x) for x in generated_raw), flush=True) if args.print_final_readable or args.print_int: full = prompt_ids + generated_raw if args.print_final_readable: print("\n[Full sequence readable]", flush=True) print(format_ids_readable(full), flush=True) print("\n[Full sequence compact]", flush=True) print(format_ids_compact(full), flush=True) if args.print_int: got = extract_first_bor_eor_bits(full, min_bits=10) if got is None: print("\n[PrintInt] Aucun bloc BOR..EOR valide trouvé.", flush=True) else: bits, val, pos = got bits_str = "".join(str(b) for b in bits) print("\n[PrintInt] First BOR..EOR", flush=True) print(f"[PrintInt] pos={pos} nbits={len(bits)} bits={bits_str} int={val}", flush=True) if __name__ == "__main__": main()