| """Text generation evaluator.""" |
|
|
| from typing import Any, Dict, List, Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from llm_lab.config import EvalConfig |
|
|
|
|
class GenerationEvaluator:
    """Evaluates text quality by generating from a battery of prompts.

    Evaluation perspectives:
      1) Grammatical accuracy: does it generate well-formed English sentences?
      2) Coherence: does it maintain context continuity?
      3) Diversity: does it produce different outputs for the same prompt?
      4) Repetition avoidance: does it avoid repeating the same phrases?
      5) Knowledge expression: is knowledge from the training data reflected?

    Realistic expectations for a ~1B-parameter model:
      - Grammatically correct English sentences: yes
      - Coherence within short paragraphs: yes
      - Complex reasoning or extended logical chains: no (needs a larger model)
      - Factual accuracy: not guaranteed
    """

    # Prompt battery grouped by what each subset probes: encyclopedic
    # knowledge, expository text, narrative continuation, code completion,
    # short factual completion, and long-context continuation.
    DEFAULT_PROMPTS = [
        # Knowledge / encyclopedic
        "The theory of relativity states that",
        "In the history of computer science,",
        "The human brain is remarkable because",
        # Instructional / expository
        "To understand machine learning, one must first",
        "The water cycle begins when",
        "Photosynthesis is the process by which",
        # Narrative continuation
        "Once upon a time, in a small village near the mountains,",
        "The detective looked at the evidence and realized that",
        # Code completion
        "def fibonacci(n):\n    \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
        "The most important data structures in programming are",
        # Short factual completion
        "The capital of France is",
        "Water boils at a temperature of",
        # Long-context continuation
        ("Artificial intelligence has transformed many industries. "
         "In healthcare, AI is used for diagnosis and drug discovery. "
         "In finance, it powers algorithmic trading and fraud detection. "
         "Looking ahead, the most promising application of AI is"),
    ]

    def __init__(self, config: "EvalConfig"):
        # Sampling hyper-parameters (num_samples, max_new_tokens,
        # temperature, top_k, top_p) are read from this config.
        self.config = config

    @torch.no_grad()
    def generate_samples(
        self,
        model: nn.Module,
        tokenizer: Any,
        device: torch.device,
        prompts: Optional[List[str]] = None,
        verbose: bool = True,
    ) -> List[Dict[str, Any]]:
        """Generates ``config.num_samples`` continuations for each prompt.

        Args:
            model: Model exposing ``generate(input_ids, max_new_tokens,
                temperature, top_k, top_p)``; assumed to return the prompt
                tokens followed by the newly generated tokens.
            tokenizer: Object providing ``encode``/``decode``.
            device: Device on which the prompt tensor is placed.
            prompts: Prompts to evaluate. Only ``None`` falls back to
                ``DEFAULT_PROMPTS``; an explicit empty list generates nothing.
            verbose: If True, prints prompts, generations, and metrics.

        Returns:
            [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
        """
        model.eval()
        # Explicit None check: a deliberately empty prompt list must be
        # respected rather than silently replaced by the defaults.
        if prompts is None:
            prompts = self.DEFAULT_PROMPTS
        results: List[Dict[str, Any]] = []

        if verbose:
            print("\n" + "=" * 70)
            print("Text Generation Evaluation")
            print("=" * 70)

        for idx, prompt in enumerate(prompts):
            prompt_results: Dict[str, Any] = {
                "prompt": prompt,
                "generations": [],
                "metrics": {},
            }

            if verbose:
                print(f"\n{'-' * 60}")
                print(f"Prompt [{idx + 1}/{len(prompts)}]:")
                print(f"  \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
                print(f"{'-' * 60}")

            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

            for sample_idx in range(self.config.num_samples):
                generated_ids = model.generate(
                    input_tensor,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=self.config.temperature,
                    top_k=self.config.top_k,
                    top_p=self.config.top_p,
                )

                # assumes generate() returns prompt + continuation — strip
                # the prompt prefix so only new tokens are decoded.
                new_ids = generated_ids[0][len(prompt_ids):].tolist()
                generated_text = tokenizer.decode(new_ids)
                prompt_results["generations"].append(generated_text)

                if verbose:
                    print(f"\n  Generation #{sample_idx + 1}:")
                    # Truncate long generations for display only; the full
                    # text is still kept in the results.
                    display_text = generated_text[:500]
                    for line in display_text.split("\n"):
                        print(f"    {line}")
                    if len(generated_text) > 500:
                        print(f"    ... (total {len(generated_text)} characters)")

            prompt_results["metrics"] = self._compute_generation_metrics(
                prompt_results["generations"]
            )

            if verbose and prompt_results["metrics"]:
                m = prompt_results["metrics"]
                print(f"\n  Metrics: "
                      f"avg_length={m['avg_length']:.0f} chars, "
                      f"repetition_rate={m['repetition_rate']:.1%}, "
                      f"lexical_diversity={m['lexical_diversity']:.2f}")

            results.append(prompt_results)

        return results

    @staticmethod
    def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
        """Computes quality metrics over a list of generated texts.

        Metrics:
            avg_length: Average generation length in characters.
            avg_word_count: Average whitespace-delimited word count.
            repetition_rate: Average fraction of duplicated 4-grams
                (lower is better).
            lexical_diversity: Average ratio of unique words per text
                (higher means more varied vocabulary).
            sample_diversity: 1 - mean pairwise Jaccard similarity of the
                per-sample word sets; 0.0 with fewer than two samples or
                no comparable pairs.

        Returns:
            Empty dict when ``texts`` is empty.
        """
        if not texts:
            return {}

        lengths = [len(t) for t in texts]
        word_counts = [len(t.split()) for t in texts]

        # Per-text 4-gram repetition: 1 - (unique 4-grams / total 4-grams).
        # Texts with fewer than four words cannot repeat a 4-gram.
        rep_rates = []
        for text in texts:
            words = text.lower().split()
            if len(words) < 4:
                rep_rates.append(0.0)
                continue
            ngrams = [tuple(words[i:i + 4]) for i in range(len(words) - 3)]
            rep_rates.append(1.0 - len(set(ngrams)) / len(ngrams))

        # Per-text type/token ratio (unique words / total words).
        diversities = [
            len(set(words)) / len(words) if words else 0.0
            for words in (t.lower().split() for t in texts)
        ]

        # Cross-sample diversity: 1 - mean pairwise Jaccard similarity.
        sample_div = 0.0
        if len(texts) > 1:
            word_sets = [set(t.lower().split()) for t in texts]
            similarities = []
            for i in range(len(word_sets)):
                for j in range(i + 1, len(word_sets)):
                    union = len(word_sets[i] | word_sets[j])
                    if union > 0:
                        similarities.append(
                            len(word_sets[i] & word_sets[j]) / union
                        )
            # Fix: with no comparable pairs (e.g. all texts empty) report
            # 0.0 instead of the spurious 1.0 the old sum/max form produced.
            if similarities:
                sample_div = 1.0 - sum(similarities) / len(similarities)

        return {
            "avg_length": sum(lengths) / len(lengths),
            "avg_word_count": sum(word_counts) / len(word_counts),
            "repetition_rate": sum(rep_rates) / len(rep_rates),
            "lexical_diversity": sum(diversities) / len(diversities),
            "sample_diversity": round(sample_div, 3),
        }
|
|