|
|
"""
|
|
|
Advanced Text Generation Utilities
|
|
|
For inference, sampling, and interactive generation
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from typing import Dict, List, Optional, Union, Callable
|
|
|
import time
|
|
|
import json
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class GenerationConfig:
|
|
|
"""Configuration for text generation"""
|
|
|
max_new_tokens: int = 512
|
|
|
temperature: float = 0.8
|
|
|
top_k: int = 50
|
|
|
top_p: float = 0.9
|
|
|
repetition_penalty: float = 1.0
|
|
|
length_penalty: float = 1.0
|
|
|
do_sample: bool = True
|
|
|
num_beams: int = 1
|
|
|
early_stopping: bool = True
|
|
|
pad_token_id: Optional[int] = None
|
|
|
eos_token_id: Optional[int] = None
|
|
|
min_length: int = 0
|
|
|
no_repeat_ngram_size: int = 0
|
|
|
|
|
|
|
|
|
class AdvancedGenerator:
|
|
|
"""Advanced text generation with multiple sampling strategies"""
|
|
|
|
|
|
def __init__(self, model, tokenizer, device):
|
|
|
self.model = model
|
|
|
self.tokenizer = tokenizer
|
|
|
self.device = device
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def generate(
|
|
|
self,
|
|
|
prompt: str,
|
|
|
config: GenerationConfig = None,
|
|
|
stream: bool = False,
|
|
|
callback: Optional[Callable[[str], None]] = None
|
|
|
) -> Union[str, List[str]]:
|
|
|
"""Generate text with advanced sampling"""
|
|
|
if config is None:
|
|
|
config = GenerationConfig()
|
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
input_ids = torch.tensor([self.tokenizer.encode(prompt)], device=self.device)
|
|
|
original_length = input_ids.size(1)
|
|
|
|
|
|
|
|
|
generated = input_ids.clone()
|
|
|
generated_tokens = []
|
|
|
|
|
|
for step in range(config.max_new_tokens):
|
|
|
|
|
|
outputs = self.model(input_ids=generated, use_cache=False)
|
|
|
logits = outputs['logits']
|
|
|
|
|
|
|
|
|
next_token_logits = logits[0, -1, :]
|
|
|
|
|
|
|
|
|
if config.repetition_penalty != 1.0:
|
|
|
next_token_logits = self._apply_repetition_penalty(
|
|
|
next_token_logits, generated[0], config.repetition_penalty
|
|
|
)
|
|
|
|
|
|
|
|
|
if config.temperature != 1.0:
|
|
|
next_token_logits = next_token_logits / config.temperature
|
|
|
|
|
|
|
|
|
if config.top_k > 0:
|
|
|
next_token_logits = self._top_k_filtering(next_token_logits, config.top_k)
|
|
|
|
|
|
|
|
|
if config.top_p < 1.0:
|
|
|
next_token_logits = self._top_p_filtering(next_token_logits, config.top_p)
|
|
|
|
|
|
|
|
|
if config.do_sample:
|
|
|
probs = F.softmax(next_token_logits, dim=-1)
|
|
|
next_token = torch.multinomial(probs, num_samples=1)
|
|
|
else:
|
|
|
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
|
|
|
|
|
|
|
|
if config.eos_token_id is not None and next_token.item() == config.eos_token_id:
|
|
|
break
|
|
|
|
|
|
|
|
|
generated = torch.cat([generated, next_token.unsqueeze(0)], dim=-1)
|
|
|
generated_tokens.append(next_token.item())
|
|
|
|
|
|
|
|
|
if stream and callback:
|
|
|
partial_text = self.tokenizer.decode(generated_tokens)
|
|
|
callback(partial_text)
|
|
|
|
|
|
|
|
|
if step < config.min_length:
|
|
|
continue
|
|
|
|
|
|
|
|
|
if config.early_stopping and self._should_stop_early(generated_tokens, config):
|
|
|
break
|
|
|
|
|
|
|
|
|
full_text = self.tokenizer.decode(generated[0].cpu().tolist())
|
|
|
generated_text = full_text[len(prompt):]
|
|
|
|
|
|
return generated_text
|
|
|
|
|
|
def _apply_repetition_penalty(self, logits, input_ids, penalty):
|
|
|
"""Apply repetition penalty to logits"""
|
|
|
score = torch.gather(logits, 0, input_ids)
|
|
|
score = torch.where(score < 0, score * penalty, score / penalty)
|
|
|
logits.scatter_(0, input_ids, score)
|
|
|
return logits
|
|
|
|
|
|
def _top_k_filtering(self, logits, top_k):
|
|
|
"""Apply top-k filtering"""
|
|
|
if top_k > 0:
|
|
|
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
|
|
logits[indices_to_remove] = float('-inf')
|
|
|
return logits
|
|
|
|
|
|
def _top_p_filtering(self, logits, top_p):
|
|
|
"""Apply top-p (nucleus) filtering"""
|
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
|
|
|
|
|
|
|
|
sorted_indices_to_remove = cumulative_probs > top_p
|
|
|
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
|
|
|
sorted_indices_to_remove[0] = 0
|
|
|
|
|
|
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
|
|
logits[indices_to_remove] = float('-inf')
|
|
|
return logits
|
|
|
|
|
|
def _should_stop_early(self, tokens, config):
|
|
|
"""Check if generation should stop early"""
|
|
|
if len(tokens) < 3:
|
|
|
return False
|
|
|
|
|
|
|
|
|
if config.no_repeat_ngram_size > 0:
|
|
|
ngram_size = config.no_repeat_ngram_size
|
|
|
if len(tokens) >= ngram_size * 2:
|
|
|
recent_ngram = tokens[-ngram_size:]
|
|
|
prev_ngram = tokens[-ngram_size*2:-ngram_size]
|
|
|
if recent_ngram == prev_ngram:
|
|
|
return True
|
|
|
|
|
|
return False
|
|
|
|
|
|
def batch_generate(
|
|
|
self,
|
|
|
prompts: List[str],
|
|
|
config: GenerationConfig = None,
|
|
|
batch_size: int = 4
|
|
|
) -> List[str]:
|
|
|
"""Generate text for multiple prompts in batches"""
|
|
|
if config is None:
|
|
|
config = GenerationConfig()
|
|
|
|
|
|
results = []
|
|
|
|
|
|
for i in range(0, len(prompts), batch_size):
|
|
|
batch_prompts = prompts[i:i + batch_size]
|
|
|
batch_results = []
|
|
|
|
|
|
for prompt in batch_prompts:
|
|
|
result = self.generate(prompt, config)
|
|
|
batch_results.append(result)
|
|
|
|
|
|
results.extend(batch_results)
|
|
|
|
|
|
return results
|
|
|
|
|
|
def interactive_chat(self, system_prompt: str = "", config: GenerationConfig = None):
|
|
|
"""Interactive chat interface"""
|
|
|
if config is None:
|
|
|
config = GenerationConfig(max_new_tokens=256, temperature=0.7)
|
|
|
|
|
|
print("=== Interactive Chat ===")
|
|
|
print("Type 'quit' to exit, 'clear' to clear history")
|
|
|
print("=" * 30)
|
|
|
|
|
|
conversation_history = system_prompt
|
|
|
|
|
|
while True:
|
|
|
user_input = input("\nYou: ").strip()
|
|
|
|
|
|
if user_input.lower() == 'quit':
|
|
|
break
|
|
|
elif user_input.lower() == 'clear':
|
|
|
conversation_history = system_prompt
|
|
|
print("History cleared.")
|
|
|
continue
|
|
|
|
|
|
|
|
|
conversation_history += f"\nUser: {user_input}\nAssistant: "
|
|
|
|
|
|
|
|
|
print("Assistant: ", end="", flush=True)
|
|
|
|
|
|
def stream_callback(text):
|
|
|
print(text, end="", flush=True)
|
|
|
|
|
|
response = self.generate(
|
|
|
conversation_history,
|
|
|
config,
|
|
|
stream=True,
|
|
|
callback=stream_callback
|
|
|
)
|
|
|
|
|
|
|
|
|
conversation_history += response
|
|
|
print()
|
|
|
|
|
|
|
|
|
class ControllableGenerator:
|
|
|
"""Generator with controllable attributes (sentiment, style, etc.)"""
|
|
|
|
|
|
def __init__(self, model, tokenizer, device):
|
|
|
self.model = model
|
|
|
self.tokenizer = tokenizer
|
|
|
self.device = device
|
|
|
self.base_generator = AdvancedGenerator(model, tokenizer, device)
|
|
|
|
|
|
def generate_with_style(
|
|
|
self,
|
|
|
prompt: str,
|
|
|
style: str = "neutral",
|
|
|
config: GenerationConfig = None
|
|
|
) -> str:
|
|
|
"""Generate text with specific style"""
|
|
|
style_prompts = {
|
|
|
"formal": "Write in a formal, professional tone: ",
|
|
|
"casual": "Write in a casual, friendly tone: ",
|
|
|
"creative": "Write creatively and imaginatively: ",
|
|
|
"technical": "Write in a technical, precise manner: ",
|
|
|
"humorous": "Write with humor and wit: ",
|
|
|
"serious": "Write in a serious, thoughtful tone: "
|
|
|
}
|
|
|
|
|
|
style_prefix = style_prompts.get(style, "")
|
|
|
full_prompt = style_prefix + prompt
|
|
|
|
|
|
return self.base_generator.generate(full_prompt, config)
|
|
|
|
|
|
def generate_with_length_control(
|
|
|
self,
|
|
|
prompt: str,
|
|
|
target_length: str = "medium",
|
|
|
config: GenerationConfig = None
|
|
|
) -> str:
|
|
|
"""Generate text with controlled length"""
|
|
|
if config is None:
|
|
|
config = GenerationConfig()
|
|
|
|
|
|
length_configs = {
|
|
|
"short": GenerationConfig(max_new_tokens=50, **config.__dict__),
|
|
|
"medium": GenerationConfig(max_new_tokens=150, **config.__dict__),
|
|
|
"long": GenerationConfig(max_new_tokens=400, **config.__dict__)
|
|
|
}
|
|
|
|
|
|
target_config = length_configs.get(target_length, config)
|
|
|
return self.base_generator.generate(prompt, target_config)
|
|
|
|
|
|
def generate_with_constraints(
|
|
|
self,
|
|
|
prompt: str,
|
|
|
must_include: List[str] = None,
|
|
|
must_avoid: List[str] = None,
|
|
|
config: GenerationConfig = None
|
|
|
) -> str:
|
|
|
"""Generate text with inclusion/exclusion constraints"""
|
|
|
if config is None:
|
|
|
config = GenerationConfig()
|
|
|
|
|
|
|
|
|
constraint_prompt = prompt
|
|
|
|
|
|
if must_include:
|
|
|
constraint_prompt += f" (Must include: {', '.join(must_include)})"
|
|
|
|
|
|
if must_avoid:
|
|
|
constraint_prompt += f" (Avoid mentioning: {', '.join(must_avoid)})"
|
|
|
|
|
|
return self.base_generator.generate(constraint_prompt, config)
|
|
|
|
|
|
|
|
|
class BenchmarkGenerator:
|
|
|
"""Generator for running generation benchmarks"""
|
|
|
|
|
|
def __init__(self, model, tokenizer, device):
|
|
|
self.generator = AdvancedGenerator(model, tokenizer, device)
|
|
|
|
|
|
def speed_benchmark(
|
|
|
self,
|
|
|
prompts: List[str],
|
|
|
config: GenerationConfig = None,
|
|
|
num_runs: int = 3
|
|
|
) -> Dict[str, float]:
|
|
|
"""Benchmark generation speed"""
|
|
|
if config is None:
|
|
|
config = GenerationConfig(max_new_tokens=100)
|
|
|
|
|
|
times = []
|
|
|
total_tokens = 0
|
|
|
|
|
|
for run in range(num_runs):
|
|
|
start_time = time.time()
|
|
|
|
|
|
for prompt in prompts:
|
|
|
result = self.generator.generate(prompt, config)
|
|
|
total_tokens += len(self.generator.tokenizer.encode(result))
|
|
|
|
|
|
end_time = time.time()
|
|
|
times.append(end_time - start_time)
|
|
|
|
|
|
avg_time = sum(times) / len(times)
|
|
|
tokens_per_second = (total_tokens / num_runs) / avg_time
|
|
|
|
|
|
return {
|
|
|
"avg_time_per_batch": avg_time,
|
|
|
"tokens_per_second": tokens_per_second,
|
|
|
"total_tokens_generated": total_tokens // num_runs,
|
|
|
"num_prompts": len(prompts)
|
|
|
}
|
|
|
|
|
|
def quality_benchmark(
|
|
|
self,
|
|
|
prompts: List[str],
|
|
|
configs: List[GenerationConfig]
|
|
|
) -> Dict[str, List[str]]:
|
|
|
"""Compare generation quality across different configs"""
|
|
|
results = {}
|
|
|
|
|
|
for i, config in enumerate(configs):
|
|
|
config_name = f"config_{i}"
|
|
|
results[config_name] = []
|
|
|
|
|
|
for prompt in prompts:
|
|
|
result = self.generator.generate(prompt, config)
|
|
|
results[config_name].append(result)
|
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
def create_generation_configs() -> Dict[str, GenerationConfig]:
|
|
|
"""Create predefined generation configurations"""
|
|
|
return {
|
|
|
"creative": GenerationConfig(
|
|
|
temperature=0.9,
|
|
|
top_k=40,
|
|
|
top_p=0.9,
|
|
|
repetition_penalty=1.1
|
|
|
),
|
|
|
"balanced": GenerationConfig(
|
|
|
temperature=0.7,
|
|
|
top_k=50,
|
|
|
top_p=0.9,
|
|
|
repetition_penalty=1.0
|
|
|
),
|
|
|
"focused": GenerationConfig(
|
|
|
temperature=0.3,
|
|
|
top_k=20,
|
|
|
top_p=0.8,
|
|
|
repetition_penalty=1.0
|
|
|
),
|
|
|
"deterministic": GenerationConfig(
|
|
|
temperature=0.0,
|
|
|
do_sample=False,
|
|
|
repetition_penalty=1.0
|
|
|
),
|
|
|
"diverse": GenerationConfig(
|
|
|
temperature=1.0,
|
|
|
top_k=100,
|
|
|
top_p=0.95,
|
|
|
repetition_penalty=1.2
|
|
|
)
|
|
|
}
|
|
|
|
|
|
|
|
|
def demo_generation(model, tokenizer, device):
|
|
|
"""Demonstrate various generation capabilities"""
|
|
|
generator = AdvancedGenerator(model, tokenizer, device)
|
|
|
controllable = ControllableGenerator(model, tokenizer, device)
|
|
|
configs = create_generation_configs()
|
|
|
|
|
|
print("=== Generation Demo ===")
|
|
|
|
|
|
|
|
|
prompt = "The future of artificial intelligence"
|
|
|
print(f"\nPrompt: {prompt}")
|
|
|
|
|
|
for name, config in configs.items():
|
|
|
print(f"\n{name.upper()} generation:")
|
|
|
result = generator.generate(prompt, config)
|
|
|
print(result[:200] + "..." if len(result) > 200 else result)
|
|
|
|
|
|
|
|
|
print("\n=== Style-Controlled Generation ===")
|
|
|
styles = ["formal", "casual", "creative", "technical"]
|
|
|
|
|
|
for style in styles:
|
|
|
print(f"\n{style.upper()} style:")
|
|
|
result = controllable.generate_with_style(prompt, style, configs["balanced"])
|
|
|
print(result[:150] + "..." if len(result) > 150 else result)
|
|
|
|
|
|
|
|
|
print("\n=== Length-Controlled Generation ===")
|
|
|
lengths = ["short", "medium", "long"]
|
|
|
|
|
|
for length in lengths:
|
|
|
print(f"\n{length.upper()} length:")
|
|
|
result = controllable.generate_with_length_control(prompt, length)
|
|
|
print(f"({len(result)} chars) {result}")
|
|
|
|
|
|
print("\n=== Demo Complete ===")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
print("Generation utilities loaded successfully!")
|
|
|
print("Use demo_generation(model, tokenizer, device) to see examples.")
|
|
|
|