"""Text generation evaluator.""" from typing import Any, Dict, List, Optional import torch import torch.nn as nn from llm_lab.config import EvalConfig class GenerationEvaluator: """Evaluates text quality by generating from various prompts. Evaluation perspectives: 1) Grammatical accuracy: Does it generate grammatically correct English sentences? 2) Coherence: Does it maintain context continuity? 3) Diversity: Does it produce different outputs for the same prompt? 4) Repetition avoidance: Does it avoid repeating the same phrases? 5) Knowledge expression: Is knowledge from the training data reflected? Realistic expectations for a 1B model: - Generates grammatically correct English sentences ✅ - Maintains coherence within short paragraphs ✅ - Complex reasoning or extended logical chains ❌ (requires a larger model) - Factual accuracy is not guaranteed ⚠️ """ # Test prompts from various domains DEFAULT_PROMPTS = [ # ── General knowledge ── "The theory of relativity states that", "In the history of computer science,", "The human brain is remarkable because", # ── Explanation / Education ── "To understand machine learning, one must first", "The water cycle begins when", "Photosynthesis is the process by which", # ── Narrative / Story ── "Once upon a time, in a small village near the mountains,", "The detective looked at the evidence and realized that", # ── Code / Technical ── "def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n", "The most important data structures in programming are", # ── Short completion ── "The capital of France is", "Water boils at a temperature of", # ── Long context ── ("Artificial intelligence has transformed many industries. " "In healthcare, AI is used for diagnosis and drug discovery. " "In finance, it powers algorithmic trading and fraud detection. " "Looking ahead, the most promising application of AI is"), ] def __init__(self, config: EvalConfig): self.config = config @torch.no_grad() def generate_samples( self, model: nn.Module, tokenizer: Any, device: torch.device, prompts: Optional[List[str]] = None, verbose: bool = True, ) -> List[Dict[str, Any]]: """Generates text for each prompt. Returns: [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...] """ model.eval() prompts = prompts or self.DEFAULT_PROMPTS results = [] if verbose: print("\n" + "=" * 70) print("📝 Text Generation Evaluation") print("=" * 70) for idx, prompt in enumerate(prompts): prompt_results = { "prompt": prompt, "generations": [], "metrics": {}, } if verbose: print(f"\n{'─'*60}") print(f"Prompt [{idx+1}/{len(prompts)}]:") print(f" \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"") print(f"{'─'*60}") # Encode prompt prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device) all_texts = [] for sample_idx in range(self.config.num_samples): # Generate generated_ids = model.generate( input_tensor, max_new_tokens=self.config.max_new_tokens, temperature=self.config.temperature, top_k=self.config.top_k, top_p=self.config.top_p, ) # Decode (only the part after the prompt) new_ids = generated_ids[0][len(prompt_ids):].tolist() generated_text = tokenizer.decode(new_ids) all_texts.append(generated_text) prompt_results["generations"].append(generated_text) if verbose: print(f"\n ✍️ Generation #{sample_idx+1}:") # Clean output (including newlines) display_text = generated_text[:500] for line in display_text.split("\n"): print(f" {line}") if len(generated_text) > 500: print(f" ... (total {len(generated_text)} characters)") # Generation quality metrics prompt_results["metrics"] = self._compute_generation_metrics(all_texts) if verbose and prompt_results["metrics"]: m = prompt_results["metrics"] print(f"\n 📊 Metrics: " f"avg_length={m['avg_length']:.0f} chars, " f"repetition_rate={m['repetition_rate']:.1%}, " f"lexical_diversity={m['lexical_diversity']:.2f}") results.append(prompt_results) return results @staticmethod def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]: """Computes quality metrics for generated text. Metrics: - avg_length: Average generation length (characters) - avg_word_count: Average word count - repetition_rate: n-gram repetition rate (lower is better) - lexical_diversity: Ratio of unique words (higher means more diverse) - sample_diversity: Diversity across samples (how different are different generations) """ if not texts: return {} # Length lengths = [len(t) for t in texts] word_counts = [len(t.split()) for t in texts] # Repetition rate (based on 4-grams) rep_rates = [] for text in texts: words = text.lower().split() if len(words) < 4: rep_rates.append(0.0) continue ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)] unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0 rep_rates.append(1.0 - unique_ratio) # repetition rate = 1 - unique ratio # Lexical diversity (Type-Token Ratio) diversities = [] for text in texts: words = text.lower().split() if words: diversities.append(len(set(words)) / len(words)) else: diversities.append(0.0) # Inter-sample diversity (inverse of Jaccard similarity) sample_div = 0.0 if len(texts) > 1: word_sets = [set(t.lower().split()) for t in texts] similarities = [] for i in range(len(word_sets)): for j in range(i+1, len(word_sets)): inter = len(word_sets[i] & word_sets[j]) union = len(word_sets[i] | word_sets[j]) if union > 0: similarities.append(inter / union) sample_div = 1.0 - (sum(similarities) / max(len(similarities), 1)) return { "avg_length": sum(lengths) / len(lengths), "avg_word_count": sum(word_counts) / len(word_counts), "repetition_rate": sum(rep_rates) / len(rep_rates), "lexical_diversity": sum(diversities) / len(diversities), "sample_diversity": round(sample_div, 3), }