File size: 2,819 Bytes
03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 c31993e 7fd7bd3 03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 7fd7bd3 03cd164 c31993e 03cd164 7fd7bd3 c31993e 03cd164 c31993e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
from src.models.agiformer import AGIFORMER
import os
import sys
def generate_text(model_path, prompt_text, max_new_tokens=200, temperature=0.8):
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Config
D_MODEL = 512
N_LAYERS = 6
PATCH_SIZE = 4
print(f"Loading {model_path} (Temp={temperature})...")
model = AGIFORMER(d_model=D_MODEL, n_layers=N_LAYERS, patch_size=PATCH_SIZE).to(DEVICE)
if not os.path.exists(model_path):
print("Model not found.")
return
state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()
# Encode prompt to UTF-8 bytes
input_bytes = list(prompt_text.encode('utf-8'))
pad_len = (PATCH_SIZE - (len(input_bytes) % PATCH_SIZE)) % PATCH_SIZE
if pad_len > 0:
input_bytes.extend([32] * pad_len)
print(f"Prompt: '{prompt_text}'")
print("-" * 50)
print(prompt_text, end='', flush=True)
generated = input_bytes[:]
with torch.no_grad():
for _ in range(max_new_tokens // PATCH_SIZE):
context = generated[-1024:] # Keep context manageable
curr_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(DEVICE)
# Pass Temperature
pred_patches = model(curr_tensor, temperature=temperature)
last_patch = pred_patches[0, -1, :].cpu().tolist()
generated.extend(last_patch)
# Real-time decoding for display is tricky with multi-byte chars
# We'll just collect and decode at the end or try best effort
pass
print("\n" + "-" * 50)
try:
full_text = bytes(generated).decode('utf-8', errors='replace')
# Print only the new part
print(full_text[len(prompt_text):])
except:
print("\n[Decoding Error]")
if __name__ == "__main__":
    # Command-line entry point: parse options and run a single generation.
    import argparse

    parser = argparse.ArgumentParser(description='Generate text with AGIFORMER')
    parser.add_argument('--prompt', type=str, default="The history of ", help='Text prompt to start generation')
    parser.add_argument('--temp', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--model', type=str, default="best_model.pth", help='Path to model checkpoint')
    parser.add_argument('--max-tokens', type=int, default=200,
                        help='Maximum number of new bytes to generate')
    args = parser.parse_args()

    # If the requested checkpoint is missing but the Turkish-trained one exists,
    # assume the user meant that one and say so.
    model_path = args.model
    fallback_path = "best_model_turkish.pth"
    if not os.path.exists(model_path) and os.path.exists(fallback_path):
        print(f"Note: '{model_path}' not found, using 'best_model_turkish.pth' instead.")
        model_path = fallback_path

    generate_text(model_path, args.prompt, max_new_tokens=args.max_tokens, temperature=args.temp)
|