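"""Sample text from a bit-level ("base2") causal LM.

The prompt is UTF-8 encoded and expanded into 0/1 token IDs (MSB first),
wrapped in BOS/EOS markers (BOS=2, EOS=3 by default), and fed to the model.
Generated bit tokens are packed back into bytes and decoded to text.

Illustrative invocation (script name is a placeholder):

    python generate.py --repo ./hf_binaryllm_repo --prompt "Hello" --stream
"""
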
import sys
import os
import argparse
import random
import codecs
from typing import List, Dict
from collections import Counter

import torch
from transformers import AutoModelForCausalLM


def decode_base2_digits_strict(digits: List[int], *, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Pack 0/1 token IDs into bytes (MSB first) and decode them to text.

    Non-binary IDs (e.g. the BOS/EOS markers) are skipped, and trailing
    bits that do not fill a whole byte are ignored.
    """
    bits: List[int] = [int(d) for d in digits if int(d) in (0, 1)]

    n_full_bytes = len(bits) // 8
    if n_full_bytes <= 0:
        return ""

    out = bytearray(n_full_bytes)
    for i in range(n_full_bytes):
        b = 0
        for bit in bits[i * 8:(i + 1) * 8]:  # MSB-first packing
            b = (b << 1) | bit
        out[i] = b
    bb = bytes(out)

    if encoding.lower() == "utf-8":
        # Incremental decode plus a final flush: with errors="replace", a
        # trailing incomplete multi-byte sequence becomes U+FFFD instead of
        # raising, which keeps partially generated output usable.
        inc = codecs.getincrementaldecoder("utf-8")(errors=errors)
        s = inc.decode(bb, final=False)
        s += inc.decode(b"", final=True)
        return s

    return bb.decode(encoding, errors=errors)


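# Example: decode_base2_digits_strict([0, 1, 0, 0, 0, 0, 0, 1]) == "A" (0x41).

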
def bytes_to_base2_digits_bytesafe(data: bytes) -> List[int]:
    """Expand each byte into 8 bit tokens, most significant bit first."""
    digits: List[int] = []
    for b in data:
        for i in range(7, -1, -1):
            digits.append((b >> i) & 1)
    return digits


def text_to_base2_digits(text: str) -> List[int]:
    """UTF-8 encode text, then expand it into bit tokens."""
    return bytes_to_base2_digits_bytesafe(text.encode("utf-8"))


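# Round trip: text_to_base2_digits("A") == [0, 1, 0, 0, 0, 0, 0, 1], and
# decode_base2_digits_strict on those digits recovers "A".

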
def wrap_base2_sequence_2(ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Wrap a bit sequence in BOS/EOS marker IDs."""
    return [int(bos_id), *ids, int(eos_id)]


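# Example: wrap_base2_sequence_2([0, 1], 2, 3) == [2, 0, 1, 3].

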
def apply_repetition_penalty_(logits: torch.Tensor, token_ids: List[int], penalty: float) -> None:
    """Apply the HF-style repetition penalty in place to already-seen tokens."""
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return
    for t in set(token_ids):
        val = logits[0, t]
        # For penalty > 1, negative logits grow in magnitude and positive
        # logits shrink, both lowering the token's probability.
        logits[0, t] = val * penalty if val < 0 else val / penalty


def apply_presence_frequency_penalties_(logits: torch.Tensor, token_ids: List[int], presence_penalty: float, frequency_penalty: float) -> None:
    """Subtract presence/frequency penalties in place, OpenAI-API style."""
    counts = Counter(token_ids)
    if presence_penalty:
        for t in counts:
            logits[0, t] -= presence_penalty
    if frequency_penalty:
        for t, c in counts.items():
            logits[0, t] -= frequency_penalty * c


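# Example: with repetition_penalty=1.2, a seen token's logit of 2.0 drops to
# ~1.67 (2.0 / 1.2); presence/frequency penalties instead subtract a flat
# amount and a count-proportional amount from each seen token's logit.

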
def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> set:
    """Return next tokens that would complete an already-seen n-gram."""
    if n <= 0 or len(seq) < n - 1:
        return set()

    prefix_len = n - 1
    ngrams: Dict[tuple, set] = {}
    for i in range(len(seq) - n + 1):
        prefix = tuple(seq[i:i + prefix_len])
        nxt = seq[i + prefix_len]
        ngrams.setdefault(prefix, set()).add(nxt)

    # Slice from len(seq) - prefix_len rather than -prefix_len so that n == 1
    # (prefix_len == 0) looks up the empty prefix instead of the whole sequence.
    return ngrams.get(tuple(seq[len(seq) - prefix_len:]), set())


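# Example: get_banned_tokens_no_repeat_ngram([0, 1, 0, 1], 3) == {0}, since
# emitting 0 after the current suffix (0, 1) would repeat the trigram (0, 1, 0).

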
def mask_banned_tokens_(logits: torch.Tensor, banned: set) -> None:
    """Set the logits of banned token IDs to -inf in place."""
    if banned:
        logits[0, list(banned)] = float("-inf")


def _maybe_hf_token() -> str:
    """Return a Hugging Face auth token from the environment, if any."""
    tok = os.environ.get("HF_TOKEN")
    if tok:
        return tok
    tok = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if tok:
        return tok
    return ""


def main() -> None:
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", type=str, required=True, help="local HF folder (./hf_binaryllm_repo) or repo_id")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
    parser.add_argument("--seed", type=int, default=-1)

parser.add_argument("--bos", type=int, default=2, help="BOS id (base2: BOS=2)") |
|
|
parser.add_argument("--eos", type=int, default=3, help="EOS id (base2: EOS=3)") |
|
|
parser.add_argument("--prompt", type=str, required=True, help="texte à encoder en base2 (UTF-8 -> bits MSB->LSB)") |
|
|
|
|
|
parser.add_argument("--max_new_tokens", type=int, default=800) |
|
|
parser.add_argument("--temperature", type=float, default=0.7) |
|
|
parser.add_argument("--top_k", type=int, default=50) |
|
|
|
|
|
parser.add_argument("--repetition_penalty", type=float, default=1.0) |
|
|
parser.add_argument("--presence_penalty", type=float, default=0.0) |
|
|
parser.add_argument("--frequency_penalty", type=float, default=0.0) |
|
|
parser.add_argument("--no_repeat_ngram_size", type=int, default=0) |
|
|
|
|
|
parser.add_argument("--decode_encoding", type=str, default="utf-8") |
|
|
parser.add_argument("--decode_errors", type=str, default="replace") |
|
|
parser.add_argument("--print_ids", action="store_true") |
|
|
parser.add_argument("--stream", action="store_true", help="stream strict (réaffiche decode strict à chaque step)") |
|
|
|
|
|
    args = parser.parse_args()

    seed = args.seed if args.seed >= 0 else random.randint(0, 2**31 - 1)
    print(f"[Seed] {seed}")
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device("cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
    print(f"[Device] {device}")

    hf_token = _maybe_hf_token()
    if hf_token:
        m = AutoModelForCausalLM.from_pretrained(args.repo, trust_remote_code=True, token=hf_token)
    else:
        m = AutoModelForCausalLM.from_pretrained(args.repo, trust_remote_code=True)

    m.to(device)
    m.eval()

    if hasattr(m, "config") and m.config is not None:
        m.config.use_cache = True

    def encode_prompt(text: str) -> List[int]:
        ids = text_to_base2_digits(text)
        ids = wrap_base2_sequence_2(ids, args.bos, args.eos)
        # Append a fresh BOS so generation starts a new sequence
        # conditioned on the encoded prompt document.
        ids = ids + [int(args.bos)]
        print("[+] PROMPT IDS =", ids)
        return ids

    prompt_ids = encode_prompt(args.prompt)

    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    generated: List[int] = []
    last_text_len = 0

    print("\n[Prompt]\n", args.prompt)
    print(f"\n[Prompt IDs] len={len(prompt_ids)} | BOS={args.bos} EOS={args.eos}")
    print("\n[Stream]" if args.stream else "\n[Output]")

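    # Per-step sampling pipeline: repetition / presence / frequency penalties,
    # optional no-repeat-ngram masking, temperature scaling, top-k filtering,
    # then multinomial sampling.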
    with torch.no_grad():
        for _ in range(int(args.max_new_tokens)):
            # Note: the full sequence is re-encoded on every step;
            # out.past_key_values is never fed back, so the KV cache is
            # rebuilt rather than reused.
            out = m(input_ids=tokens, use_cache=True)
            logits = out.logits[:, -1, :]

            full_seq = tokens[0].tolist()

            apply_repetition_penalty_(logits, full_seq, float(args.repetition_penalty))
            apply_presence_frequency_penalties_(logits, full_seq, float(args.presence_penalty), float(args.frequency_penalty))

            if int(args.no_repeat_ngram_size) > 0:
                banned = get_banned_tokens_no_repeat_ngram(full_seq, int(args.no_repeat_ngram_size))
                mask_banned_tokens_(logits, banned)

            logits = logits / max(float(args.temperature), 1e-6)

            if 0 < int(args.top_k) < logits.size(-1):
                v, _ = torch.topk(logits, int(args.top_k))
                logits[logits < v[:, [-1]]] = float("-inf")

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tok_id = int(next_token.item())

            if tok_id == int(args.eos):
                break

            tokens = torch.cat([tokens, next_token], dim=1)
            generated.append(tok_id)

            if args.stream:
                # Re-decode the whole bit stream and print only the new
                # suffix. A partially generated multi-byte character may
                # briefly surface as U+FFFD under errors="replace".
                text = decode_base2_digits_strict(generated, encoding=args.decode_encoding, errors=args.decode_errors)
                if len(text) > last_text_len:
                    sys.stdout.write(text[last_text_len:])
                    sys.stdout.flush()
                    last_text_len = len(text)

    if args.stream:
        print()

print("\n[Final Output]\n") |
|
|
print(decode_base2_digits_strict(generated, encoding=args.decode_encoding, errors=args.decode_errors)) |
|
|
|
|
|
if args.print_ids: |
|
|
print("\n[Generated IDs]\n") |
|
|
print(generated) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|