| from tokenizers import Tokenizer
|
| import torch
|
| import numpy as np
|
| import time
|
| import os
|
| from datetime import datetime
|
|
|
|
|
def process_string_into_pairs(input_str: str) -> list[str]:
    """Split *input_str* into the chunks used for mask-token lookup.

    Rules, applied left to right:
      * two consecutive lowercase ASCII letters become one two-char chunk;
      * a lone lowercase letter followed by a space becomes a one-char
        chunk and the space is consumed (dropped from the output);
      * any other character is emitted unchanged as its own chunk.
    """
    chunks: list[str] = []
    length = len(input_str)
    pos = 0
    while pos < length:
        current = input_str[pos]
        nxt = input_str[pos + 1] if pos + 1 < length else None
        is_lower = "a" <= current <= "z"
        if is_lower and nxt is not None and "a" <= nxt <= "z":
            # Two lowercase letters -> a single bigram chunk.
            chunks.append(current + nxt)
            pos += 2
        elif is_lower and nxt == " ":
            # Lowercase letter before a space: keep the letter, eat the space.
            chunks.append(current)
            pos += 2
        else:
            # Anything else (uppercase, digits, punctuation, spaces not
            # preceded by a lone lowercase letter) passes through verbatim.
            chunks.append(current)
            pos += 1
    return chunks
|
|
|
|
|
def get_mask_from_string(input_str: str, tokenizer) -> torch.Tensor:
    """Map a raw string to a 1-D ``torch.long`` tensor of token ids.

    Each chunk produced by ``process_string_into_pairs`` is wrapped as
    ``<|mask_<chunk>|>`` when it is pure ASCII; non-ASCII chunks are looked
    up as-is. Ids come from ``tokenizer.token_to_id``.
    """
    token_ids = []
    for chunk in process_string_into_pairs(input_str):
        is_ascii = all(ord(ch) < 128 for ch in chunk)
        lookup = f"<|mask_{chunk}|>" if is_ascii else chunk
        token_ids.append(tokenizer.token_to_id(lookup))
    return torch.tensor(token_ids, dtype=torch.long)
|
|
|
|
|
def inference(model, input_str: str, tokenizer, device, threshold: float = 0.9) -> None:
    """Iteratively unmask *input_str* with *model*, printing progress.

    The string is converted to mask tokens, then repeatedly fed through the
    model; any masked position whose top prediction exceeds *threshold* is
    fixed to that prediction. If no position clears the threshold in a
    round, the single most confident position is fixed anyway so the loop
    always terminates. The partially decoded string is printed each round.
    Returns nothing (output is via ``print``).
    """
    model.eval()

    # Optional engram support: rebuild the n-gram hash mapping from the
    # model config when one is declared.
    engram_cfg = model.config.engram_config
    hash_mapping = None
    if engram_cfg is not None:
        from modeling_llada_engram import ModelConfig, EngramConfig, NgramHashMapping
        from dataclasses import fields

        backbone_config_dict = model.config.to_dict()

        # Keep only the keys ModelConfig actually declares as fields.
        backbone_config = ModelConfig(**{k: v for k, v in backbone_config_dict.items() if k in [f.name for f in fields(ModelConfig)]})

        hash_mapping = NgramHashMapping(
            engram_vocab_size = engram_cfg.get('engram_vocab_size', [129280*5, 129280*5]),
            max_ngram_size = engram_cfg.get('max_ngram_size', 3),
            n_embed_per_ngram = engram_cfg.get('n_embed_per_ngram', 512),
            n_head_per_ngram = engram_cfg.get('n_head_per_ngram', 8),
            layer_ids = engram_cfg.get('layer_ids', [1, 15]),
            pad_id = engram_cfg.get('pad_id', 2),
            seed = engram_cfg.get('seed', 0),
            config = backbone_config,
        )
        # NOTE(review): hash_mapping is built but never used below — it is
        # not passed to model(...). Confirm whether the model picks it up
        # internally or this is dead code.

    with torch.no_grad():
        # Shape (1, seq_len) tensor of mask-token ids on the target device.
        mask_tensor = get_mask_from_string(input_str, tokenizer).unsqueeze(0).to(device)

        # Assumes every mask token id is >= the id of "<|mask|>" — i.e. the
        # mask tokens occupy the top of the vocabulary. TODO confirm against
        # the tokenizer layout.
        is_masked = mask_tensor >= tokenizer.token_to_id("<|mask|>")
        rounds = 0  # round counter; currently only incremented, never read
        while is_masked.any():
            rounds += 1

            # Logits for every position; [0] selects the logits tensor from
            # the model output tuple.
            output = model(input_ids=mask_tensor)[0]

            output = torch.softmax(output, dim=-1)
            unmasked_any = False
            prob_info = []  # per-position "prob token" strings; collected but not printed (debug leftover?)

            # (probability, position, token_id) of the best prediction seen
            # this round — used as a fallback when nothing clears threshold.
            most_certain_token = (0, 0, 0)

            for i in range(mask_tensor.shape[1]):
                if is_masked[0, i]:
                    # Greedy prediction for this masked position.
                    predicted_token = output[0, i].argmax().item()
                    prob_info.append(
                        f"{output[0, i, predicted_token].item():.2f} {tokenizer.id_to_token(predicted_token)}"
                    )
                    # Tuple comparison: probability first, so this tracks the
                    # highest-confidence masked position.
                    most_certain_token = max(
                        most_certain_token,
                        (output[0, i, predicted_token].item(), i, predicted_token)
                    )

                    # Commit confident predictions immediately.
                    if output[0, i, predicted_token].item() > threshold:
                        mask_tensor[0, i] = predicted_token
                        is_masked[0, i] = False
                        unmasked_any = True
                else:
                    prob_info.append("")
            if not unmasked_any:
                # Nothing cleared the threshold: force-commit the single most
                # confident position so the loop makes progress.
                mask_tensor[0, most_certain_token[1]] = most_certain_token[2]
                is_masked[0, most_certain_token[1]] = False

            # Render current state: decoded tokens verbatim, still-masked
            # tokens with the "<|mask_" prefix (7 chars) and "|>" suffix
            # (2 chars) stripped so the underlying chunk shows through.
            masked_str = "".join(
                (
                    tokenizer.id_to_token(mask_tensor[0, i].item())
                    if not is_masked[0, i]
                    else tokenizer.id_to_token(mask_tensor[0, i].item())[7:-2]
                )
                for i in range(mask_tensor.shape[1])
            )
            print(masked_str)
|
|
|
|
|
if __name__ == "__main__":
    # Prefer GPU when available; everything below follows this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = Tokenizer.from_file("tokenizer.json")

    # Load the model from the current directory via transformers' auto
    # machinery (trust_remote_code is required for custom architectures).
    try:
        from transformers import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True).to(device)
    except Exception as e:
        print(f"Failed to load with AutoModel: {e}")
        print("Falling back to manual loading (if needed, but prefer AutoModel for validation)")
        # Bare `raise` re-raises the active exception with its original
        # traceback intact (idiomatic; `raise e` would add a redundant frame).
        raise

    # bfloat16 on CUDA for speed/memory; full float32 on CPU, where bf16
    # support is spotty.
    model = model.to(torch.bfloat16) if device.type == "cuda" else model.float()
    print("Loaded model. Parameters:", sum(p.numel() for p in model.parameters()))

    # Confidence threshold for committing a predicted token per round.
    threshold = 0.9

    # Interactive REPL: Ctrl-C / EOF to exit.
    while True:
        input_str = input("Enter a string to process: ")
        inference(model, input_str, tokenizer, device, threshold=threshold)
        print("")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|