# LLM-1B-Lab — llm_lab/evaluation/generation.py
# History: "docs: translate all Korean comments and docstrings to English"
# (author: Vjeong, commit 858e8b2)
"""Text generation evaluator."""
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from llm_lab.config import EvalConfig
class GenerationEvaluator:
    """Evaluates text quality by generating from various prompts.

    Evaluation perspectives:
        1) Grammatical accuracy: does it generate grammatically correct English sentences?
        2) Coherence: does it maintain context continuity?
        3) Diversity: does it produce different outputs for the same prompt?
        4) Repetition avoidance: does it avoid repeating the same phrases?
        5) Knowledge expression: is knowledge from the training data reflected?

    Realistic expectations for a 1B model:
        - Generates grammatically correct English sentences βœ…
        - Maintains coherence within short paragraphs βœ…
        - Complex reasoning or extended logical chains ❌ (requires a larger model)
        - Factual accuracy is not guaranteed ⚠️
    """

    # Test prompts from various domains
    DEFAULT_PROMPTS = [
        # ── General knowledge ──
        "The theory of relativity states that",
        "In the history of computer science,",
        "The human brain is remarkable because",
        # ── Explanation / Education ──
        "To understand machine learning, one must first",
        "The water cycle begins when",
        "Photosynthesis is the process by which",
        # ── Narrative / Story ──
        "Once upon a time, in a small village near the mountains,",
        "The detective looked at the evidence and realized that",
        # ── Code / Technical ──
        "def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
        "The most important data structures in programming are",
        # ── Short completion ──
        "The capital of France is",
        "Water boils at a temperature of",
        # ── Long context ──
        ("Artificial intelligence has transformed many industries. "
         "In healthcare, AI is used for diagnosis and drug discovery. "
         "In finance, it powers algorithmic trading and fraud detection. "
         "Looking ahead, the most promising application of AI is"),
    ]

    def __init__(self, config: "EvalConfig"):
        """Stores the evaluation configuration.

        Args:
            config: Supplies the sampling parameters read during generation
                (``num_samples``, ``max_new_tokens``, ``temperature``,
                ``top_k``, ``top_p``).
        """
        self.config = config

    @torch.no_grad()
    def generate_samples(
        self,
        model: nn.Module,
        tokenizer: Any,
        device: torch.device,
        prompts: Optional[List[str]] = None,
        verbose: bool = True,
    ) -> List[Dict[str, Any]]:
        """Generates ``config.num_samples`` completions for each prompt.

        NOTE: this switches ``model`` to eval mode and does not restore the
        previous mode; callers that were mid-training must re-enable it.

        Args:
            model: Model exposing ``generate(input_ids, ...)`` that returns
                the prompt tokens followed by the newly generated tokens.
            tokenizer: Object providing ``encode``/``decode``.
            device: Device on which the prompt tensor is created.
            prompts: Prompts to complete; ``DEFAULT_PROMPTS`` when omitted.
            verbose: If True, pretty-prints generations and summary metrics.

        Returns:
            [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
        """
        model.eval()
        prompts = prompts or self.DEFAULT_PROMPTS
        results: List[Dict[str, Any]] = []

        if verbose:
            print("\n" + "=" * 70)
            print("πŸ“ Text Generation Evaluation")
            print("=" * 70)

        for idx, prompt in enumerate(prompts):
            if verbose:
                print(f"\n{'─'*60}")
                print(f"Prompt [{idx+1}/{len(prompts)}]:")
                print(f" \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
                print(f"{'─'*60}")

            # Encode the prompt once; reused for every sample of this prompt.
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            prompt_len = len(prompt_ids)
            input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

            # Single list of generations (previously this was tracked twice,
            # in both a local list and the result dict).
            generations: List[str] = []
            for sample_idx in range(self.config.num_samples):
                generated_ids = model.generate(
                    input_tensor,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=self.config.temperature,
                    top_k=self.config.top_k,
                    top_p=self.config.top_p,
                )
                # Decode only the continuation (tokens after the prompt).
                new_ids = generated_ids[0][prompt_len:].tolist()
                generated_text = tokenizer.decode(new_ids)
                generations.append(generated_text)

                if verbose:
                    self._print_generation(sample_idx, generated_text)

            metrics = self._compute_generation_metrics(generations)
            if verbose and metrics:
                print(f"\n πŸ“Š Metrics: "
                      f"avg_length={metrics['avg_length']:.0f} chars, "
                      f"repetition_rate={metrics['repetition_rate']:.1%}, "
                      f"lexical_diversity={metrics['lexical_diversity']:.2f}")

            results.append({
                "prompt": prompt,
                "generations": generations,
                "metrics": metrics,
            })

        return results

    @staticmethod
    def _print_generation(sample_idx: int, generated_text: str) -> None:
        """Pretty-prints a single generation, truncated to 500 characters."""
        print(f"\n ✍️ Generation #{sample_idx+1}:")
        # Clean output (including newlines)
        display_text = generated_text[:500]
        for line in display_text.split("\n"):
            print(f" {line}")
        if len(generated_text) > 500:
            print(f" ... (total {len(generated_text)} characters)")

    @staticmethod
    def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
        """Computes quality metrics for generated text.

        Metrics:
            - avg_length: Average generation length (characters)
            - avg_word_count: Average word count
            - repetition_rate: 4-gram repetition rate (lower is better)
            - lexical_diversity: Ratio of unique words (higher means more diverse)
            - sample_diversity: 1 - mean pairwise Jaccard similarity of the
              samples' word sets (how different the generations are)

        Returns:
            Metric dict, or an empty dict when ``texts`` is empty.
        """
        if not texts:
            return {}

        # Length in characters and in whitespace-separated words.
        lengths = [len(t) for t in texts]
        word_counts = [len(t.split()) for t in texts]

        # Repetition rate per text: 1 - (unique 4-grams / total 4-grams).
        # Texts with fewer than 4 words have no 4-grams and count as 0.0.
        rep_rates: List[float] = []
        for text in texts:
            words = text.lower().split()
            if len(words) < 4:
                rep_rates.append(0.0)
                continue
            ngrams = [tuple(words[i:i + 4]) for i in range(len(words) - 3)]
            rep_rates.append(1.0 - len(set(ngrams)) / len(ngrams))

        # Lexical diversity (Type-Token Ratio); an empty text counts as 0.0.
        diversities = [
            len(set(words)) / len(words) if words else 0.0
            for words in (text.lower().split() for text in texts)
        ]

        # Inter-sample diversity: 1 - mean pairwise Jaccard similarity.
        sample_div = 0.0
        if len(texts) > 1:
            word_sets = [set(t.lower().split()) for t in texts]
            similarities = [
                len(a & b) / len(a | b)
                for i, a in enumerate(word_sets)
                for b in word_sets[i + 1:]
                if a | b  # skip pairs where both texts are empty
            ]
            sample_div = 1.0 - (sum(similarities) / max(len(similarities), 1))

        return {
            "avg_length": sum(lengths) / len(lengths),
            "avg_word_count": sum(word_counts) / len(word_counts),
            "repetition_rate": sum(rep_rates) / len(rep_rates),
            "lexical_diversity": sum(diversities) / len(diversities),
            # Rounded to 3 decimals, matching the original reporting format.
            "sample_diversity": round(sample_div, 3),
        }