# tinyvic/generate.py
"""
VicAI Text Generation
Interactive text generation and sampling utilities.
"""
import argparse

import torch

from model import create_vicai_5b
from tokenizer import ByteLevelBPETokenizer

def generate_interactive(
    model,
    tokenizer,
    device,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
):
    """Interactive text generation loop."""
    print("\n" + "=" * 60)
    print("VicAI Interactive Generation")
    print("=" * 60)
    print("Commands:")
    print("  /quit      - Exit the program")
    print("  /config    - Show current generation settings")
    print("  /temp X    - Set temperature (0.1 - 2.0)")
    print("  /topk X    - Set top-k (1 - 100)")
    print("  /topp X    - Set top-p (0.0 - 1.0)")
    print("  /reppen X  - Set repetition penalty (1.0 - 2.0)")
    print("  /maxlen X  - Set max new tokens")
    print("=" * 60 + "\n")

    # Current settings, adjustable at runtime via the slash commands above
    settings = {
        'temperature': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'max_new_tokens': max_new_tokens,
    }
    while True:
        try:
            # Get prompt
            prompt = input("\nPrompt: ").strip()

            # Handle commands
            if prompt == '/quit':
                print("Goodbye!")
                break
            if prompt == '/config':
                print("\nCurrent settings:")
                for key, value in settings.items():
                    print(f"  {key}: {value}")
                continue
            if prompt.startswith('/temp '):
                try:
                    settings['temperature'] = float(prompt.split()[1])
                    print(f"Temperature set to {settings['temperature']}")
                except (ValueError, IndexError):
                    print("Invalid temperature value")
                continue
            if prompt.startswith('/topk '):
                try:
                    settings['top_k'] = int(prompt.split()[1])
                    print(f"Top-k set to {settings['top_k']}")
                except (ValueError, IndexError):
                    print("Invalid top-k value")
                continue
            if prompt.startswith('/topp '):
                try:
                    settings['top_p'] = float(prompt.split()[1])
                    print(f"Top-p set to {settings['top_p']}")
                except (ValueError, IndexError):
                    print("Invalid top-p value")
                continue
            if prompt.startswith('/reppen '):
                try:
                    settings['repetition_penalty'] = float(prompt.split()[1])
                    print(f"Repetition penalty set to {settings['repetition_penalty']}")
                except (ValueError, IndexError):
                    print("Invalid repetition penalty value")
                continue
            if prompt.startswith('/maxlen '):
                try:
                    settings['max_new_tokens'] = int(prompt.split()[1])
                    print(f"Max new tokens set to {settings['max_new_tokens']}")
                except (ValueError, IndexError):
                    print("Invalid max new tokens value")
                continue
            if not prompt:
                continue

            # Encode prompt
            input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)

            # Generate
            print("\nGenerating...")
            with torch.no_grad():
                output_ids = model.generate(
                    input_ids,
                    max_new_tokens=settings['max_new_tokens'],
                    temperature=settings['temperature'],
                    top_k=settings['top_k'],
                    top_p=settings['top_p'],
                    repetition_penalty=settings['repetition_penalty'],
                    eos_token_id=tokenizer.eos_token_id,
                )

            # Decode, then strip the original prompt from the output
            generated_text = tokenizer.decode(output_ids[0].tolist())
            prompt_text = tokenizer.decode(input_ids[0].tolist())
            if generated_text.startswith(prompt_text):
                generated_text = generated_text[len(prompt_text):].strip()
            print("\n" + "-" * 60)
            print("Generated:")
            print("-" * 60)
            print(generated_text)
            print("-" * 60)

            # Print token info
            num_tokens = output_ids.shape[1] - input_ids.shape[1]
            print(f"\nTokens generated: {num_tokens}")
        except EOFError:
            # Stdin closed (e.g. piped input exhausted): exit cleanly instead
            # of looping on the same error forever.
            print("\nGoodbye!")
            break
        except KeyboardInterrupt:
            print("\n\nInterrupted by user. Type /quit to exit.")
        except Exception as e:
            print(f"\nError: {e}")

def generate_batch(
    model,
    tokenizer,
    prompts: list,
    device,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
):
    """Generate completions for multiple prompts (processed sequentially, one at a time)."""
    results = []
    for prompt in prompts:
        input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Strip the original prompt from the decoded output
        generated_text = tokenizer.decode(output_ids[0].tolist())
        prompt_text = tokenizer.decode(input_ids[0].tolist())
        if generated_text.startswith(prompt_text):
            generated_text = generated_text[len(prompt_text):].strip()
        results.append({
            'prompt': prompt,
            'completion': generated_text,
        })
    return results
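
# generate_batch is not wired into the CLI in main() below; a minimal sketch of
# programmatic use (the prompts here are placeholders):
#   results = generate_batch(model, tokenizer, ["Hello", "The sky is"], device)
#   for r in results:
#       print(r['prompt'], '->', r['completion'])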

def benchmark_generation(
    model,
    tokenizer,
    device,
    num_runs: int = 10,
    max_new_tokens: int = 128,
    prompt: str = "The future of artificial intelligence is",
):
    """Benchmark generation speed."""
    import time

    def sync():
        # Synchronize only on CUDA; torch.cuda.synchronize() raises on CPU-only setups
        if device.type == 'cuda':
            torch.cuda.synchronize()

    print(f"\nBenchmarking generation ({num_runs} runs)...")
    input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)

    # Warmup
    with torch.no_grad():
        _ = model.generate(input_ids, max_new_tokens=10)
    sync()

    # Benchmark
    times = []
    tokens_generated = []
    for i in range(num_runs):
        start = time.time()
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=1.0,
            )
        sync()
        elapsed = time.time() - start
        num_tokens = output.shape[1] - input_ids.shape[1]
        times.append(elapsed)
        tokens_generated.append(num_tokens)
        print(f"  Run {i+1}: {num_tokens} tokens in {elapsed:.2f}s ({num_tokens/elapsed:.1f} tok/s)")

    avg_time = sum(times) / len(times)
    avg_tokens = sum(tokens_generated) / len(tokens_generated)
    avg_speed = avg_tokens / avg_time
    print(f"\nAverage: {avg_tokens:.1f} tokens in {avg_time:.2f}s ({avg_speed:.1f} tok/s)")

def main():
    parser = argparse.ArgumentParser(description='Generate text with VicAI')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--tokenizer', type=str, default='tokenizer.pkl', help='Path to tokenizer')
    parser.add_argument('--prompt', type=str, default=None, help='Single prompt to generate from')
    parser.add_argument('--interactive', action='store_true', help='Interactive mode')
    parser.add_argument('--max-new-tokens', type=int, default=256, help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=50, help='Top-k sampling')
    parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) sampling')
    parser.add_argument('--repetition-penalty', type=float, default=1.1, help='Repetition penalty')
    parser.add_argument('--benchmark', action='store_true', help='Run generation benchmark')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    args = parser.parse_args()

    # Setup device (fall back to CPU when CUDA is unavailable)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load tokenizer (ByteLevelBPETokenizer is the trained default)
    print(f"Loading tokenizer from {args.tokenizer}...")
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.load(args.tokenizer)
    print(f"Tokenizer loaded: {len(tokenizer)} tokens")

    # Load model
    print(f"Loading model from {args.checkpoint}...")
    checkpoint = torch.load(args.checkpoint, map_location=device)

    # Create model (assuming the 5B config)
    model = create_vicai_5b(vocab_size=len(tokenizer))

    # Load weights; the checkpoint may be a training dict or a bare state_dict
    state_dict = checkpoint.get('model', checkpoint)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print(f"Model loaded: ~{model.get_num_params() / 1e9:.2f}B parameters")

    # Run benchmark if requested
    if args.benchmark:
        benchmark_generation(model, tokenizer, device)
        return

    # Interactive mode (the default when no prompt is given)
    if args.interactive or args.prompt is None:
        generate_interactive(
            model,
            tokenizer,
            device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )
    else:
        # Single prompt generation
        print(f"\nPrompt: {args.prompt}")
        print("-" * 60)
        input_ids = torch.tensor([tokenizer.encode(args.prompt)], device=device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
                eos_token_id=tokenizer.eos_token_id,
            )
        generated_text = tokenizer.decode(output_ids[0].tolist())
        prompt_text = tokenizer.decode(input_ids[0].tolist())
        if generated_text.startswith(prompt_text):
            generated_text = generated_text[len(prompt_text):].strip()
        print(generated_text)
        print("-" * 60)
        num_tokens = output_ids.shape[1] - input_ids.shape[1]
        print(f"\nGenerated {num_tokens} tokens")


if __name__ == '__main__':
    main()