PhysiQuanty commited on
Commit
4a9fbd4
·
verified ·
1 Parent(s): 333c419

inference-ready export

Browse files
Files changed (1) hide show
  1. inference.py +249 -0
inference.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # infer.py
3
+ # ============================================================
4
+ # HF inference (CausalLM) en base-2
5
+ # - Encode le --prompt en bits (MSB->LSB) comme llmTalk
6
+ # - Prompt = [BOS] + bits + [EOS] + [BOS] (reset latent)
7
+ # - PAS de KV-cache (use_cache=False) => "comme entraînement" (full forward)
8
+ # - Boucle manuelle token-par-token (pas model.generate)
9
+ # - Décodage FINAL via decode_base2_digits_strict
10
+ # - indentation AVEC TABULATIONS (comme ton fichier actuel)
11
+ # ============================================================
12
+
13
+ import sys
14
+ import os
15
+ import argparse
16
+ import random
17
+ import codecs
18
+ from typing import List, Dict
19
+ from collections import Counter
20
+
21
+ import torch
22
+ from transformers import AutoModelForCausalLM
23
+
24
+
25
+ def decode_base2_digits_strict(digits: List[int], *, encoding: str = "utf-8", errors: str = "replace") -> str:
26
+ # Filtre minimal: ne garder que 0/1 (au cas où)
27
+ bits: List[int] = []
28
+ for d in digits:
29
+ di = int(d)
30
+ if di == 0 or di == 1:
31
+ bits.append(di)
32
+
33
+ n_full_bytes = len(bits) // 8
34
+ if n_full_bytes <= 0:
35
+ return ""
36
+
37
+ out = bytearray(n_full_bytes)
38
+
39
+ j = 0
40
+ for i in range(n_full_bytes):
41
+ # MSB -> LSB (bits[j] est le bit de poids fort)
42
+ b = 0
43
+ b = (b << 1) | bits[j + 0]
44
+ b = (b << 1) | bits[j + 1]
45
+ b = (b << 1) | bits[j + 2]
46
+ b = (b << 1) | bits[j + 3]
47
+ b = (b << 1) | bits[j + 4]
48
+ b = (b << 1) | bits[j + 5]
49
+ b = (b << 1) | bits[j + 6]
50
+ b = (b << 1) | bits[j + 7]
51
+ out[i] = b
52
+ j += 8
53
+
54
+ bb = bytes(out)
55
+
56
+ # Décodage robuste UTF-8 (gère proprement les séquences multi-octets)
57
+ if encoding.lower() == "utf-8":
58
+ inc = codecs.getincrementaldecoder("utf-8")(errors=errors)
59
+ s = inc.decode(bb, final=False)
60
+ s += inc.decode(b"", final=True)
61
+ return s
62
+
63
+ return bb.decode(encoding, errors=errors)
64
+
65
+
66
+ def bytes_to_base2_digits_bytesafe(data: bytes) -> List[int]:
67
+ digits: List[int] = []
68
+ for b in data:
69
+ for i in range(7, -1, -1):
70
+ digits.append((b >> i) & 1)
71
+ return digits
72
+
73
+
74
+ def text_to_base2_digits(text: str) -> List[int]:
75
+ # Même logique que llmTalk: UTF-8 -> bits MSB->LSB
76
+ return bytes_to_base2_digits_bytesafe(text.encode("utf-8"))
77
+
78
+
79
+ def wrap_base2_sequence_2(ids: List[int], bos_id: int, eos_id: int) -> List[int]:
80
+ return [int(bos_id), *ids, int(eos_id)]
81
+
82
+
83
+ def apply_repetition_penalty_(logits: torch.Tensor, token_ids: List[int], penalty: float) -> None:
84
+ if penalty is None or penalty == 1.0 or penalty <= 0:
85
+ return
86
+ for t in set(token_ids):
87
+ val = logits[0, t]
88
+ logits[0, t] = val * penalty if val < 0 else val / penalty
89
+
90
+
91
+ def apply_presence_frequency_penalties_(logits: torch.Tensor, token_ids: List[int], presence_penalty: float, frequency_penalty: float) -> None:
92
+ counts = Counter(token_ids)
93
+ if presence_penalty:
94
+ for t in counts:
95
+ logits[0, t] -= presence_penalty
96
+ if frequency_penalty:
97
+ for t, c in counts.items():
98
+ logits[0, t] -= frequency_penalty * c
99
+
100
+
101
+ def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> set:
102
+ if n <= 0 or len(seq) < n - 1:
103
+ return set()
104
+
105
+ prefix_len = n - 1
106
+ ngrams: Dict[tuple, set] = {}
107
+ for i in range(len(seq) - n + 1):
108
+ prefix = tuple(seq[i:i + prefix_len])
109
+ nxt = seq[i + prefix_len]
110
+ ngrams.setdefault(prefix, set()).add(nxt)
111
+
112
+ return ngrams.get(tuple(seq[-prefix_len:]), set())
113
+
114
+
115
+ def mask_banned_tokens_(logits: torch.Tensor, banned: set) -> None:
116
+ if banned:
117
+ logits[0, list(banned)] = float("-inf")
118
+
119
+
120
+ def _maybe_hf_token() -> str:
121
+ tok = os.environ.get("HF_TOKEN")
122
+ if tok:
123
+ return tok
124
+ tok = os.environ.get("HUGGINGFACE_HUB_TOKEN")
125
+ if tok:
126
+ return tok
127
+ return ""
128
+
129
+
130
+ def main() -> None:
131
+ parser = argparse.ArgumentParser()
132
+
133
+ parser.add_argument("--repo", type=str, required=True, help="chemin dossier HF local (./hf_binaryllm_repo) ou repo_id")
134
+ parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
135
+ parser.add_argument("--seed", type=int, default=-1)
136
+
137
+ # Base-2 avec 2 spéciaux => vocab_size=4 attendu: 0,1 + BOS=2 + EOS=3
138
+ parser.add_argument("--bos", type=int, default=2, help="BOS id (base2: BOS=2)")
139
+ parser.add_argument("--eos", type=int, default=3, help="EOS id (base2: EOS=3)")
140
+ parser.add_argument("--prompt", type=str, required=True, help="texte à encoder en base2 (UTF-8 -> bits MSB->LSB)")
141
+
142
+ parser.add_argument("--max_new_tokens", type=int, default=800)
143
+ parser.add_argument("--temperature", type=float, default=0.7)
144
+ parser.add_argument("--top_k", type=int, default=50)
145
+
146
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
147
+ parser.add_argument("--presence_penalty", type=float, default=0.0)
148
+ parser.add_argument("--frequency_penalty", type=float, default=0.0)
149
+ parser.add_argument("--no_repeat_ngram_size", type=int, default=0)
150
+
151
+ parser.add_argument("--decode_encoding", type=str, default="utf-8")
152
+ parser.add_argument("--decode_errors", type=str, default="replace")
153
+ parser.add_argument("--print_ids", action="store_true")
154
+ parser.add_argument("--stream", action="store_true", help="stream strict (réaffiche decode strict à chaque step)")
155
+
156
+ args = parser.parse_args()
157
+
158
+ seed = args.seed if args.seed >= 0 else random.randint(0, 2**31 - 1)
159
+ print(f"[Seed] {seed}")
160
+ torch.manual_seed(seed)
161
+ if torch.cuda.is_available():
162
+ torch.cuda.manual_seed_all(seed)
163
+
164
+ device = torch.device("cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
165
+ print(f"[Device] {device}")
166
+
167
+ # --------- Load HF model ---------
168
+ hf_token = _maybe_hf_token()
169
+ if hf_token:
170
+ m = AutoModelForCausalLM.from_pretrained(args.repo, trust_remote_code=True, token=hf_token)
171
+ else:
172
+ m = AutoModelForCausalLM.from_pretrained(args.repo, trust_remote_code=True)
173
+
174
+ m.to(device)
175
+ m.eval()
176
+
177
+ # IMPORTANT: pas de KV-cache (train-like)
178
+ if hasattr(m, "config") and m.config is not None:
179
+ m.config.use_cache = False
180
+
181
+ # --------- Encode prompt EXACTEMENT comme llmTalk (base=2) ---------
182
+ def encode_prompt(text: str) -> List[int]:
183
+ ids = text_to_base2_digits(text) # 0/1 bits (MSB->LSB)
184
+ ids = wrap_base2_sequence_2(ids, args.bos, args.eos) # [BOS] bits [EOS]
185
+ ids = ids + [int(args.bos)] # reset latent: ...[EOS][BOS]
186
+ print("[+] IDS = ", ids) # debug (tu supprimeras avant commit)
187
+ return ids
188
+
189
+ prompt_ids = encode_prompt(args.prompt)
190
+
191
+ tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)
192
+ generated: List[int] = []
193
+ last_text_len = 0
194
+
195
+ print("\n[Prompt]\n", args.prompt)
196
+ print(f"\n[Prompt IDs] len={len(prompt_ids)} | BOS={args.bos} EOS={args.eos}")
197
+ print("\n[Stream]" if args.stream else "\n[Output]")
198
+
199
+ with torch.no_grad():
200
+ for _ in range(int(args.max_new_tokens)):
201
+ # full forward sur toute la séquence, sans cache
202
+ out = m(input_ids=tokens, use_cache=False)
203
+ logits = out.logits[:, -1, :]
204
+
205
+ full_seq = tokens[0].tolist()
206
+
207
+ apply_repetition_penalty_(logits, full_seq, float(args.repetition_penalty))
208
+ apply_presence_frequency_penalties_(logits, full_seq, float(args.presence_penalty), float(args.frequency_penalty))
209
+
210
+ if int(args.no_repeat_ngram_size) > 0:
211
+ banned = get_banned_tokens_no_repeat_ngram(full_seq, int(args.no_repeat_ngram_size))
212
+ mask_banned_tokens_(logits, banned)
213
+
214
+ logits = logits / max(float(args.temperature), 1e-6)
215
+
216
+ if 0 < int(args.top_k) < logits.size(-1):
217
+ v, _ = torch.topk(logits, int(args.top_k))
218
+ logits[logits < v[:, [-1]]] = float("-inf")
219
+
220
+ probs = torch.softmax(logits, dim=-1)
221
+ next_token = torch.multinomial(probs, 1)
222
+ tok_id = int(next_token.item())
223
+
224
+ if tok_id == int(args.eos):
225
+ break
226
+
227
+ tokens = torch.cat([tokens, next_token], dim=1)
228
+ generated.append(tok_id)
229
+
230
+ if args.stream:
231
+ text = decode_base2_digits_strict(generated, encoding=args.decode_encoding, errors=args.decode_errors)
232
+ if len(text) > last_text_len:
233
+ sys.stdout.write(text[last_text_len:])
234
+ sys.stdout.flush()
235
+ last_text_len = len(text)
236
+
237
+ if args.stream:
238
+ print()
239
+
240
+ print("\n[Final Output]\n")
241
+ print(decode_base2_digits_strict(generated, encoding=args.decode_encoding, errors=args.decode_errors))
242
+
243
+ if args.print_ids:
244
+ print("\n[Generated IDs]\n")
245
+ print(generated)
246
+
247
+
248
+ if __name__ == "__main__":
249
+ main()