PhysiQuanty commited on
Commit
feea3b3
·
verified ·
1 Parent(s): a6c916c

export inference-ready

Browse files
Files changed (1) hide show
  1. inference.py +385 -0
inference.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # llmTalk_ids_v8_hf.py
3
+ # ============================================================
4
+ # INFERENCE EN IDS UNIQUEMENT (vocab=8):
5
+ # 0/1 bits + 6 specials: BOS EOS BOI EOI BOR EOR
6
+ #
7
+ # Deux modes de prompt:
8
+ # - --prompt_ids : string de chiffres (ex: "240000001540000015") (digits only, 0..7) (peut être "")
9
+ # - --prompt_int : string "int,int" -> génère: BOS t0 t1 BOI int1(10b) EOI BOI int2(10b) EOI
10
+ #
11
+ # Option:
12
+ # - --print_int : extrait le premier bloc BOR ... EOR (bits variables) dans la séquence complète
13
+ # et affiche sa valeur décimale (binaire -> int).
14
+ # (min_bits=10 par défaut pour coller à tes entrées 10 bits, mais la réponse peut dépasser)
15
+ # ============================================================
16
+
17
+ import sys
18
+ import argparse
19
+ import random
20
+ from collections import Counter
21
+ from typing import List, Dict, Tuple, Any, Optional
22
+
23
+ import torch
24
+ from transformers import AutoModelForCausalLM
25
+
26
# ----------------------------
# Special tokens (vocab=8)
# ----------------------------
# Token ids 0 and 1 are the data bits; ids 2..7 are control markers.
TOK_BOS = 2  # begin of sequence
TOK_EOS = 3  # end of sequence
TOK_BOI = 4  # begin of input block
TOK_EOI = 5  # end of input block
TOK_BOR = 6  # begin of result block
TOK_EOR = 7  # end of result block

# Human-readable names used by the pretty-printers below.
TOK_NAMES = {
    0: "0",
    1: "1",
    TOK_BOS: "BOS",
    TOK_EOS: "EOS",
    TOK_BOI: "BOI",
    TOK_EOI: "EOI",
    TOK_BOR: "BOR",
    TOK_EOR: "EOR",
}

# ------------------------------------------------------------
# Task header bits for --prompt_int (t0, t1)
# ------------------------------------------------------------
# The prompt layout is "BOS t0 t1 ..." but t0/t1 were never specified,
# so a neutral default of 0,0 is used here (change if needed).
PROMPT_INT_T0 = 0
PROMPT_INT_T1 = 0
54
+
55
+ # ----------------------------
56
+ # Logits modifiers
57
+ # ----------------------------
58
def apply_repetition_penalty_(logits: torch.Tensor, token_ids: List[int], penalty: float) -> None:
    """Penalize, in place, the logits of tokens already present in ``token_ids``.

    Follows the classic CTRL/HF rule: negative logits are multiplied by
    ``penalty`` and non-negative ones divided by it, so penalty > 1 always
    makes seen tokens less likely. ``None``, 1.0 or non-positive values
    are treated as "disabled".
    """
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return
    for tok in set(token_ids):
        current = logits[0, tok]
        if current < 0:
            logits[0, tok] = current * penalty
        else:
            logits[0, tok] = current / penalty
64
+
65
def apply_encoder_repetition_penalty_(logits: torch.Tensor, prompt_token_ids: List[int], penalty: float) -> None:
    """Rescale, in place, the logits of tokens that appear in the prompt.

    Mirror image of the repetition penalty: with ``penalty`` > 1 the prompt
    tokens are *boosted* (negative logits divided, non-negative multiplied).
    ``None``, 1.0 or non-positive values disable the transform.
    """
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return
    for tok in set(prompt_token_ids):
        current = logits[0, tok]
        logits[0, tok] = current / penalty if current < 0 else current * penalty
71
+
72
def apply_presence_frequency_penalties_(
    logits: torch.Tensor,
    token_ids: List[int],
    presence_penalty: float,
    frequency_penalty: float,
) -> None:
    """Apply OpenAI-style presence/frequency penalties to ``logits`` in place.

    Every token that occurs in ``token_ids`` loses ``presence_penalty`` once,
    plus ``frequency_penalty`` multiplied by its occurrence count. A zero
    penalty is a no-op for that term.
    """
    occurrences = Counter(token_ids)

    if presence_penalty:
        for tok in occurrences:
            logits[0, tok] -= presence_penalty

    if frequency_penalty:
        for tok, count in occurrences.items():
            logits[0, tok] -= frequency_penalty * count
87
+
88
def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> set:
    """Return the token ids that would complete an already-seen n-gram.

    Implements the classic ``no_repeat_ngram_size`` constraint: a token ``t``
    is banned for the next position if appending it to ``seq`` would reproduce
    an n-gram that already occurs in ``seq``.

    Args:
        seq: full token sequence generated so far.
        n:   n-gram size; ``n <= 0`` disables the constraint.

    Returns:
        Set of banned token ids (possibly empty).
    """
    if n <= 0 or len(seq) < n - 1:
        return set()

    prefix_len = n - 1
    # Map each (n-1)-token prefix to the set of tokens that followed it.
    ngrams: Dict[Tuple[int, ...], set] = {}
    for i in range(len(seq) - n + 1):
        prefix = tuple(seq[i:i + prefix_len])
        ngrams.setdefault(prefix, set()).add(seq[i + prefix_len])

    # BUGFIX: slice from the left. The previous ``seq[-prefix_len:]`` breaks
    # for n == 1, because ``seq[-0:]`` is the WHOLE sequence instead of an
    # empty tail, so the lookup key was wrong and nothing was ever banned.
    tail = tuple(seq[len(seq) - prefix_len:])
    return ngrams.get(tail, set())
100
+
101
def mask_banned_tokens_(logits: torch.Tensor, banned: set) -> None:
    """Set the logits of every banned token id to -inf, in place."""
    if not banned:
        return
    logits[0, list(banned)] = float("-inf")
104
+
105
+ # ----------------------------
106
+ # Helpers: prompt parsing + pretty print
107
+ # ----------------------------
108
def parse_prompt_ids_str(s: str, vocab_size: int = 8) -> List[int]:
    """Turn a string of digit characters into a list of token ids.

    ``None`` or a blank string yields an empty prompt. Any non-digit
    character, or a digit outside ``[0, vocab_size)``, raises ValueError.
    """
    text = ("" if s is None else str(s)).strip()
    if not text:
        return []

    if not text.isdigit():
        raise ValueError("prompt_ids doit contenir uniquement des chiffres (0..7), sans espaces.")

    ids: List[int] = []
    for ch in text:
        tok = ord(ch) - ord("0")
        if not (0 <= tok < vocab_size):
            raise ValueError(f"token id hors vocab: {tok} (vocab_size={vocab_size})")
        ids.append(tok)
    return ids
124
+
125
def format_ids_readable(ids: List[int]) -> str:
    """Render token ids as space-separated symbolic names (bits stay '0'/'1')."""
    return " ".join(TOK_NAMES.get(int(t), str(int(t))) for t in ids)
130
+
131
def format_ids_compact(ids: List[int]) -> str:
    """Render ids with consecutive 0/1 bits fused into single digit runs.

    Special tokens keep their symbolic names, e.g. [2, 0, 0, 1, 5]
    becomes "BOS 001 EOI".
    """
    chunks: List[str] = []
    for t in ids:
        ti = int(t)
        if ti in (0, 1):
            # Extend the current run when the previous chunk ends in a bit.
            if chunks and chunks[-1][-1] in ("0", "1"):
                chunks[-1] += str(ti)
            else:
                chunks.append(str(ti))
        else:
            chunks.append(TOK_NAMES.get(ti, str(ti)))
    return " ".join(chunks)
143
+
144
+ # ----------------------------
145
+ # --prompt_int builder
146
+ # ----------------------------
147
def int_to_10bits_tokens(x: int) -> List[int]:
    """Encode an integer in [0, 1023] as exactly 10 bit tokens, MSB first."""
    if x < 0 or x > 1023:
        raise ValueError(f"int hors range pour 10 bits: {x} (attendu 0..1023)")
    value = int(x)
    # Walk bit positions 9..0 and mask each one out.
    return [(value >> shift) & 1 for shift in range(9, -1, -1)]
152
+
153
def parse_prompt_int_str(s: str) -> Tuple[int, int]:
    """Parse an "int,int" string into a pair of integers.

    Whitespace around the string and around each field is ignored.

    Raises:
        ValueError: on an empty/None string, a wrong number of
            comma-separated fields, or non-integer fields.
    """
    s = "" if s is None else str(s)
    s = s.strip()
    if s == "":
        raise ValueError("--prompt_int vide. Attendu: \"int,int\"")

    parts = s.split(",")
    if len(parts) != 2:
        raise ValueError(f"--prompt_int invalide: {s!r}. Attendu: \"int,int\"")

    try:
        a = int(parts[0].strip())
        b = int(parts[1].strip())
    except ValueError:
        # Narrowed from a bare ``except Exception`` (int(str) can only raise
        # ValueError here); chain suppressed since our message is complete.
        raise ValueError(f"--prompt_int invalide: {s!r}. Les deux valeurs doivent être des int.") from None

    return a, b
170
+
171
def build_prompt_from_ints(int1: int, int2: int) -> List[int]:
    """Build the canonical prompt: BOS t0 t1 BOI <10 bits> EOI BOI <10 bits> EOI."""
    prompt: List[int] = [TOK_BOS, int(PROMPT_INT_T0), int(PROMPT_INT_T1)]
    for value in (int1, int2):
        prompt.append(TOK_BOI)
        prompt.extend(int_to_10bits_tokens(value))
        prompt.append(TOK_EOI)
    return prompt
186
+
187
+ # ----------------------------
188
+ # --print_int extractor (BOR ... EOR, bits variables)
189
+ # ----------------------------
190
def extract_first_bor_eor_bits(ids: List[int], min_bits: int = 1) -> Optional[Tuple[List[int], int, int]]:
    """Locate the first BOR marker and decode the bits that follow it.

    Scans from the first BOR up to the next EOR (or, leniently, the end of
    ``ids`` if no EOR appears), keeping only 0/1 tokens and ignoring any
    other specials in between.

    Returns:
        (bits, value, bor_position) where ``value`` is the MSB-first integer
        reading of ``bits``; or None when no BOR exists or fewer than
        ``min_bits`` bits were collected.
    """
    if TOK_BOR not in ids:
        return None
    start = ids.index(TOK_BOR)

    bits: List[int] = []
    for raw in ids[start + 1:]:
        tok = int(raw)
        if tok == TOK_EOR:
            break
        if tok in (0, 1):
            bits.append(tok)

    if len(bits) < int(min_bits):
        return None

    value = 0
    for bit in bits:
        value = (value << 1) | int(bit)

    return bits, value, start
214
+
215
+ # ----------------------------
216
+ # Main
217
+ # ----------------------------
218
def main() -> None:
    """CLI entry point: load the HF model and sample raw token ids.

    Pipeline: parse args -> seed RNGs -> load the causal-LM from the hub
    (trust_remote_code) -> build prompt ids from --prompt_int or --prompt_ids
    -> per-step sampling loop with optional penalties / top-k /
    no-repeat-ngram -> report the generated ids (and optionally decode the
    first BOR..EOR block as an integer).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", type=str, required=True, help='HF repo id ou path local (ex: "PhysiQuanty/xxx")')
    parser.add_argument("--revision", type=str, default=None, help="HF revision/branch/tag/commit (optionnel)")

    # The two prompt sources are mutually exclusive; neither is required.
    g = parser.add_mutually_exclusive_group(required=False)
    g.add_argument("--prompt_ids", type=str, default=None, help='Ex: "240000001540000015" (digits only 0..7) or ""')
    g.add_argument("--prompt_int", type=str, default=None, help='Ex: "12,900" -> BOS t0 t1 BOI 10b EOI BOI 10b EOI')

    parser.add_argument("--print_int", action="store_true", help="Affiche le 1er bloc BOR..EOR (bits) en int")

    parser.add_argument("--max_new_tokens", type=int, default=40)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=50)

    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--presence_penalty", type=float, default=0.0)
    parser.add_argument("--frequency_penalty", type=float, default=0.0)
    parser.add_argument("--encoder_repetition_penalty", type=float, default=1.0)
    parser.add_argument("--no_repeat_ngram_size", type=int, default=0)

    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])

    parser.add_argument("--stream_ids", action="store_true", help="Stream les IDS générés au fil de l'eau")
    parser.add_argument("--print_prompt_readable", action="store_true", help="Affiche prompt en tokens lisibles")
    parser.add_argument("--print_final_readable", action="store_true", help="Affiche sortie finale en tokens lisibles")
    parser.add_argument("--stop_on_eos", action="store_true", help="Stop dès que EOS(3) est généré")

    args = parser.parse_args()

    # A negative --seed means "pick one at random"; it is printed for replay.
    seed = args.seed if args.seed >= 0 else random.randint(0, 2**31 - 1)
    print(f"[Seed] {seed}", flush=True)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Silently fall back to CPU when CUDA was requested but is unavailable.
    device = torch.device("cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
    print(f"[Device] {device}", flush=True)

    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        args.repo,
        revision=args.revision,
        trust_remote_code=True,  # NOTE(review): runs code from the repo — only point at trusted repos
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    model.to(device)
    model.eval()

    vocab_size_cfg = int(getattr(model.config, "vocab_size", -1))
    print(f"[Model] loaded from {args.repo} | vocab_size={vocab_size_cfg}", flush=True)
    if vocab_size_cfg != 8:
        print(f"[Warn] vocab_size={vocab_size_cfg} (attendu 8).", flush=True)

    # ---- build prompt ids from either --prompt_int or --prompt_ids (or default "")
    if args.prompt_int is not None:
        int1, int2 = parse_prompt_int_str(args.prompt_int)
        prompt_ids = build_prompt_from_ints(int1, int2)
        prompt_origin = f'prompt_int="{args.prompt_int}" (t0,t1={PROMPT_INT_T0},{PROMPT_INT_T1})'
    else:
        s = "" if args.prompt_ids is None else args.prompt_ids
        prompt_ids = parse_prompt_ids_str(s, vocab_size=8)
        prompt_origin = 'prompt_ids' if args.prompt_ids is not None else 'prompt_ids="" (default)'

    print(f"[Prompt Origin] {prompt_origin}", flush=True)

    if args.print_prompt_readable:
        print(f"[Prompt IDs] {prompt_ids}", flush=True)
        print(f"[Prompt readable] {format_ids_readable(prompt_ids)}", flush=True)
        print(f"[Prompt compact] {format_ids_compact(prompt_ids)}", flush=True)
    else:
        if len(prompt_ids) == 0:
            print("[Prompt IDs] len=0 (prompt nul)", flush=True)
        else:
            print(f"[Prompt IDs] len={len(prompt_ids)} first32={prompt_ids[:32]}", flush=True)

    # An empty prompt cannot feed the model: seed with a lone BOS token
    # (excluded from the reported output via generated_raw).
    seeded_with_bos = False
    if len(prompt_ids) == 0:
        tokens = torch.tensor([TOK_BOS], device=device, dtype=torch.long).unsqueeze(0)
        seeded_with_bos = True
    else:
        tokens = torch.tensor(prompt_ids, device=device, dtype=torch.long).unsqueeze(0)

    generated_raw: List[int] = []

    if args.stream_ids:
        sys.stdout.write("[Stream IDS] ")
        sys.stdout.flush()

    with torch.no_grad():
        for _ in range(int(args.max_new_tokens)):
            # Full forward pass on the whole sequence each step (no KV cache);
            # acceptable for a tiny vocab-8 model and short sequences.
            out = model(input_ids=tokens)
            logits = out.logits[:, -1, :]  # (1, vocab)

            # Penalties mutate in place, so work on a copy of the last-step logits.
            logits_work = logits.clone()
            full_seq = tokens[0].tolist()

            apply_encoder_repetition_penalty_(logits_work, prompt_ids, float(args.encoder_repetition_penalty))
            apply_repetition_penalty_(logits_work, full_seq, float(args.repetition_penalty))
            apply_presence_frequency_penalties_(
                logits_work,
                full_seq,
                float(args.presence_penalty),
                float(args.frequency_penalty),
            )

            if int(args.no_repeat_ngram_size) > 0:
                banned = get_banned_tokens_no_repeat_ngram(full_seq, int(args.no_repeat_ngram_size))
                mask_banned_tokens_(logits_work, banned)

            # Temperature floor (1e-6) avoids division by zero at temperature 0.
            logits_work /= max(float(args.temperature), 1e-6)

            # Top-k filtering: mask everything strictly below the k-th best logit.
            if 0 < int(args.top_k) < logits_work.size(-1):
                v, _ = torch.topk(logits_work, int(args.top_k))
                logits_work[logits_work < v[:, [-1]]] = float("-inf")

            probs = torch.softmax(logits_work, dim=-1)
            next_token = torch.multinomial(probs, 1)  # (1,1)
            tok_id = int(next_token.item())
            generated_raw.append(tok_id)

            if args.stream_ids:
                sys.stdout.write(str(tok_id))
                sys.stdout.flush()

            tokens = torch.cat([tokens, next_token], dim=1)

            if args.stop_on_eos and tok_id == TOK_EOS:
                break

    if args.stream_ids:
        sys.stdout.write("\n")
        sys.stdout.flush()

    if seeded_with_bos:
        print("\n[Prompt] prompt nul -> seed interne BOS(2) utilisé uniquement pour init logits", flush=True)

    print("\n[Generated RAW IDS]", flush=True)
    print(generated_raw, flush=True)

    print("\n[Generated RAW IDS (as digits)]", flush=True)
    print("".join(str(x) for x in generated_raw), flush=True)

    if args.print_final_readable or args.print_int:
        full = prompt_ids + generated_raw

        if args.print_final_readable:
            print("\n[Full sequence readable]", flush=True)
            print(format_ids_readable(full), flush=True)
            print("\n[Full sequence compact]", flush=True)
            print(format_ids_compact(full), flush=True)

        if args.print_int:
            # min_bits=10 matches the 10-bit inputs; longer answers still pass.
            got = extract_first_bor_eor_bits(full, min_bits=10)
            if got is None:
                print("\n[PrintInt] Aucun bloc BOR..EOR valide trouvé.", flush=True)
            else:
                bits, val, pos = got
                bits_str = "".join(str(b) for b in bits)
                print("\n[PrintInt] First BOR..EOR", flush=True)
                print(f"[PrintInt] pos={pos} nbits={len(bits)} bits={bits_str} int={val}", flush=True)
383
+
384
# Script entry point.
if __name__ == "__main__":
    main()