"""Text generation evaluator."""
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from llm_lab.config import EvalConfig
class GenerationEvaluator:
    """Probes a language model's output quality across a spread of prompts.

    Evaluation angles:
      1) Grammar      — are the generated sentences well-formed English?
      2) Coherence    — does the continuation stay on topic?
      3) Diversity    — do repeated samples from one prompt differ?
      4) Repetition   — does the model avoid looping on the same phrase?
      5) Knowledge    — does training-data knowledge surface in the text?

    Realistic expectations for a ~1B-parameter model:
      - well-formed English sentences                  ✓
      - coherence within short paragraphs              ✓
      - complex reasoning / long logical chains        ✗ (needs a larger model)
      - factual accuracy                               ⚠️ not guaranteed
    """

    # Prompts spanning several domains; used whenever the caller supplies none.
    DEFAULT_PROMPTS = [
        # General knowledge
        "The theory of relativity states that",
        "In the history of computer science,",
        "The human brain is remarkable because",
        # Explanation / education
        "To understand machine learning, one must first",
        "The water cycle begins when",
        "Photosynthesis is the process by which",
        # Narrative / story
        "Once upon a time, in a small village near the mountains,",
        "The detective looked at the evidence and realized that",
        # Code / technical
        "def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
        "The most important data structures in programming are",
        # Short completion
        "The capital of France is",
        "Water boils at a temperature of",
        # Long context
        ("Artificial intelligence has transformed many industries. "
         "In healthcare, AI is used for diagnosis and drug discovery. "
         "In finance, it powers algorithmic trading and fraud detection. "
         "Looking ahead, the most promising application of AI is"),
    ]

    def __init__(self, config: EvalConfig):
        # Sampling hyper-parameters (num_samples, max_new_tokens, temperature,
        # top_k, top_p) are read off this config during generation.
        self.config = config

    @torch.no_grad()
    def generate_samples(
        self,
        model: nn.Module,
        tokenizer: Any,
        device: torch.device,
        prompts: Optional[List[str]] = None,
        verbose: bool = True,
    ) -> List[Dict[str, Any]]:
        """Sample completions for every prompt and attach quality metrics.

        Returns:
            [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
        """
        model.eval()
        if not prompts:
            prompts = self.DEFAULT_PROMPTS
        total = len(prompts)
        report: List[Dict[str, Any]] = []

        if verbose:
            print("\n" + "=" * 70)
            print("🎯 Text Generation Evaluation")
            print("=" * 70)

        for position, prompt in enumerate(prompts, start=1):
            if verbose:
                bar = "─" * 60
                print("\n" + bar)
                print(f"Prompt [{position}/{total}]:")
                preview = prompt[:80] + ("..." if len(prompt) > 80 else "")
                print(f" \"{preview}\"")
                print(bar)

            # Tokenize the prompt once; every sample reuses the same batch.
            token_ids = tokenizer.encode(prompt, add_special_tokens=False)
            batch = torch.tensor([token_ids], dtype=torch.long, device=device)

            completions: List[str] = []
            for draw in range(self.config.num_samples):
                output = model.generate(
                    batch,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=self.config.temperature,
                    top_k=self.config.top_k,
                    top_p=self.config.top_p,
                )
                # Drop the prompt tokens; decode only the continuation.
                continuation = tokenizer.decode(output[0][len(token_ids):].tolist())
                completions.append(continuation)

                if verbose:
                    print(f"\n ✍️ Generation #{draw+1}:")
                    # Cap the displayed text; print line by line for readability.
                    shown = continuation[:500]
                    for row in shown.split("\n"):
                        print(f" {row}")
                    if len(continuation) > 500:
                        print(f" ... (total {len(continuation)} characters)")

            scores = self._compute_generation_metrics(completions)
            if verbose and scores:
                print(f"\n 📊 Metrics: "
                      f"avg_length={scores['avg_length']:.0f} chars, "
                      f"repetition_rate={scores['repetition_rate']:.1%}, "
                      f"lexical_diversity={scores['lexical_diversity']:.2f}")

            report.append({
                "prompt": prompt,
                "generations": completions,
                "metrics": scores,
            })
        return report

    @staticmethod
    def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
        """Score a batch of generated texts.

        Keys in the returned dict (empty dict for empty input):
            avg_length:        mean length in characters
            avg_word_count:    mean whitespace-token count
            repetition_rate:   mean 4-gram repetition rate (lower is better)
            lexical_diversity: mean type-token ratio (higher is more varied)
            sample_diversity:  1 - mean pairwise Jaccard overlap across samples
        """
        if not texts:
            return {}

        def four_gram_repetition(text: str) -> float:
            # repetition = 1 - (unique 4-grams / total 4-grams)
            tokens = text.lower().split()
            if len(tokens) < 4:
                return 0.0
            grams = [tuple(tokens[k:k + 4]) for k in range(len(tokens) - 3)]
            unique_share = len(set(grams)) / len(grams) if grams else 1.0
            return 1.0 - unique_share

        def type_token_ratio(text: str) -> float:
            tokens = text.lower().split()
            return len(set(tokens)) / len(tokens) if tokens else 0.0

        char_lengths = [len(t) for t in texts]
        token_counts = [len(t.split()) for t in texts]
        repetitions = [four_gram_repetition(t) for t in texts]
        ratios = [type_token_ratio(t) for t in texts]

        # Cross-sample diversity: complement of the mean pairwise Jaccard
        # similarity between the word sets of different generations.
        cross_diversity = 0.0
        if len(texts) > 1:
            vocab = [set(t.lower().split()) for t in texts]
            overlaps = []
            for a, left in enumerate(vocab):
                for right in vocab[a + 1:]:
                    union_size = len(left | right)
                    if union_size > 0:
                        overlaps.append(len(left & right) / union_size)
            cross_diversity = 1.0 - (sum(overlaps) / max(len(overlaps), 1))

        n = len(texts)
        return {
            "avg_length": sum(char_lengths) / n,
            "avg_word_count": sum(token_counts) / n,
            "repetition_rate": sum(repetitions) / n,
            "lexical_diversity": sum(ratios) / n,
            "sample_diversity": round(cross_diversity, 3),
        }