# agiformer / generate.py
# Last commit: c31993e by tefoteknik (verified) — "Update AGIFORMER with Turkish benchmark"
import torch
from src.models.agiformer import AGIFORMER
import os
import sys
def generate_text(model_path, prompt_text, max_new_tokens=200, temperature=0.8, *,
                  d_model=512, n_layers=6, patch_size=4, context_size=1024):
    """Generate a byte-level continuation of a prompt with an AGIFORMER checkpoint.

    The prompt is encoded to UTF-8 bytes and right-padded with spaces (byte 32) to
    a whole number of patches. Each step feeds the trailing context window to the
    model and appends the last predicted patch. All bytes are decoded once at the
    end, because multi-byte UTF-8 characters may be split across patches.

    Args:
        model_path: Path to a file holding a state_dict for AGIFORMER.
        prompt_text: Text to condition generation on.
        max_new_tokens: Byte budget. ``max_new_tokens // patch_size`` patches are
            generated, so values below ``patch_size`` generate nothing.
        temperature: Forwarded to the model's forward call as ``temperature=``.
        d_model, n_layers, patch_size: Model hyper-parameters. They must match the
            checkpoint. The defaults are the values this script always used.
        context_size: Maximum number of trailing bytes fed to the model per step.
            Rounded down to a multiple of ``patch_size`` (at least one patch) so
            the window stays patch-aligned.

    Returns:
        The decoded continuation (padding spaces + generated bytes, with invalid
        UTF-8 replaced by U+FFFD). Returns ``None`` if the checkpoint file is
        missing or the generated values cannot be converted to bytes.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Loading {model_path} (Temp={temperature})...")
    # Check before building the model so a missing file doesn't cost a model allocation.
    if not os.path.exists(model_path):
        print("Model not found.")
        return None

    model = AGIFORMER(d_model=d_model, n_layers=n_layers, patch_size=patch_size).to(device)
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Encode prompt to UTF-8 bytes and pad with spaces to a whole number of patches.
    prompt_bytes = list(prompt_text.encode('utf-8'))
    pad_len = (patch_size - (len(prompt_bytes) % patch_size)) % patch_size
    generated = prompt_bytes + [32] * pad_len
    if not generated:
        # An empty prompt would give the model an empty sequence; start from one patch of spaces.
        generated = [32] * patch_size

    print(f"Prompt: '{prompt_text}'")
    print("-" * 50)
    print(prompt_text, end='', flush=True)

    # A window that is a multiple of patch_size keeps the slice below patch-aligned.
    window = max(patch_size, context_size - context_size % patch_size)
    with torch.no_grad():
        for _ in range(max_new_tokens // patch_size):
            context = generated[-window:]
            curr_tensor = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
            # assumes output is (batch, n_patches, patch_size) byte ids — TODO confirm
            pred_patches = model(curr_tensor, temperature=temperature)
            generated.extend(pred_patches[0, -1, :].cpu().tolist())

    print("\n" + "-" * 50)
    # Decode only what follows the prompt bytes (padding + generated bytes).
    try:
        continuation = bytes(generated[len(prompt_bytes):]).decode('utf-8', errors='replace')
    except (ValueError, TypeError):
        # bytes() rejects values outside 0..255 (ValueError) or non-integers (TypeError).
        print("\n[Decoding Error]")
        return None
    print(continuation)
    return continuation
if __name__ == "__main__":
    import argparse

    DEFAULT_MODEL = "best_model.pth"
    # Used only when --model is omitted and DEFAULT_MODEL does not exist.
    FALLBACK_MODEL = "best_model_turkish.pth"

    parser = argparse.ArgumentParser(description='Generate text with AGIFORMER')
    parser.add_argument('--prompt', type=str, default="The history of ", help='Text prompt to start generation')
    parser.add_argument('--temp', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--model', type=str, default=None,
                        help=f'Path to model checkpoint (default: {DEFAULT_MODEL}, '
                             f'or {FALLBACK_MODEL} if that is missing)')
    parser.add_argument('--max-tokens', type=int, default=200,
                        help='Number of new bytes to generate')
    args = parser.parse_args()

    if args.model is not None:
        # An explicit path is honoured as-is; never silently swap in another model.
        model_path = args.model
    else:
        model_path = DEFAULT_MODEL
        # The Turkish checkpoint may be saved under a different name.
        if not os.path.exists(model_path) and os.path.exists(FALLBACK_MODEL):
            print(f"Note: '{model_path}' not found, using '{FALLBACK_MODEL}' instead.")
            model_path = FALLBACK_MODEL

    generate_text(model_path, args.prompt, max_new_tokens=args.max_tokens, temperature=args.temp)