File size: 7,706 Bytes
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
858e8b2
 
 
 
 
 
 
 
 
 
 
 
 
 
8a58ffe
 
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
 
858e8b2
8a58ffe
 
 
 
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
858e8b2
 
8a58ffe
 
 
 
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
 
858e8b2
 
 
 
8a58ffe
 
 
 
 
 
 
858e8b2
 
 
 
 
 
 
 
8a58ffe
 
 
 
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""Text generation evaluator."""

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from llm_lab.config import EvalConfig


class GenerationEvaluator:
    """Evaluates text quality by generating from various prompts.

    Evaluation perspectives:
      1) Grammatical accuracy:  Does it generate grammatically correct English sentences?
      2) Coherence:             Does it maintain context continuity?
      3) Diversity:             Does it produce different outputs for the same prompt?
      4) Repetition avoidance:  Does it avoid repeating the same phrases?
      5) Knowledge expression:  Is knowledge from the training data reflected?

    Realistic expectations for a 1B model:
      - Generates grammatically correct English sentences βœ…
      - Maintains coherence within short paragraphs βœ…
      - Complex reasoning or extended logical chains ❌ (requires a larger model)
      - Factual accuracy is not guaranteed ⚠️
    """

    # Test prompts from various domains
    DEFAULT_PROMPTS = [
        # ── General knowledge ──
        "The theory of relativity states that",
        "In the history of computer science,",
        "The human brain is remarkable because",

        # ── Explanation / Education ──
        "To understand machine learning, one must first",
        "The water cycle begins when",
        "Photosynthesis is the process by which",

        # ── Narrative / Story ──
        "Once upon a time, in a small village near the mountains,",
        "The detective looked at the evidence and realized that",

        # ── Code / Technical ──
        "def fibonacci(n):\n    \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
        "The most important data structures in programming are",

        # ── Short completion ──
        "The capital of France is",
        "Water boils at a temperature of",

        # ── Long context ──
        ("Artificial intelligence has transformed many industries. "
         "In healthcare, AI is used for diagnosis and drug discovery. "
         "In finance, it powers algorithmic trading and fraud detection. "
         "Looking ahead, the most promising application of AI is"),
    ]

    def __init__(self, config: EvalConfig):
        # EvalConfig is expected to provide the sampling hyperparameters read
        # below: num_samples, max_new_tokens, temperature, top_k, top_p.
        self.config = config

    @torch.no_grad()
    def generate_samples(
        self,
        model: nn.Module,
        tokenizer: Any,
        device: torch.device,
        prompts: Optional[List[str]] = None,
        verbose: bool = True,
    ) -> List[Dict[str, Any]]:
        """Generates text for each prompt and scores each prompt's samples.

        Args:
            model: Module exposing ``generate(input_ids, max_new_tokens=...,
                temperature=..., top_k=..., top_p=...)`` returning token ids
                that include the prompt prefix.
            tokenizer: Object with ``encode(text, add_special_tokens=...)`` and
                ``decode(ids)`` methods.
            device: Device onto which the encoded prompt tensor is placed.
            prompts: Prompts to complete; defaults to ``DEFAULT_PROMPTS``.
            verbose: If True, pretty-prints prompts, generations, and metrics.

        Returns:
            [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
        """
        model.eval()
        prompts = prompts or self.DEFAULT_PROMPTS
        results = []

        if verbose:
            print("\n" + "=" * 70)
            print("πŸ“ Text Generation Evaluation")
            print("=" * 70)

        for idx, prompt in enumerate(prompts):
            prompt_results: Dict[str, Any] = {
                "prompt": prompt,
                "generations": [],
                "metrics": {},
            }

            if verbose:
                print(f"\n{'─'*60}")
                print(f"Prompt [{idx+1}/{len(prompts)}]:")
                print(f"  \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
                print(f"{'─'*60}")

            # Encode prompt once; reuse the tensor for every sample.
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

            for sample_idx in range(self.config.num_samples):
                # Sample a continuation from the model.
                generated_ids = model.generate(
                    input_tensor,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=self.config.temperature,
                    top_k=self.config.top_k,
                    top_p=self.config.top_p,
                )

                # Decode only the continuation (tokens after the prompt).
                new_ids = generated_ids[0][len(prompt_ids):].tolist()
                generated_text = tokenizer.decode(new_ids)
                prompt_results["generations"].append(generated_text)

                if verbose:
                    print(f"\n  ✍️ Generation #{sample_idx+1}:")
                    # Clean output (including newlines)
                    display_text = generated_text[:500]
                    for line in display_text.split("\n"):
                        print(f"    {line}")
                    if len(generated_text) > 500:
                        print(f"    ... (total {len(generated_text)} characters)")

            # Quality metrics computed over all samples for this prompt.
            prompt_results["metrics"] = self._compute_generation_metrics(
                prompt_results["generations"]
            )

            if verbose and prompt_results["metrics"]:
                m = prompt_results["metrics"]
                print(f"\n  πŸ“Š Metrics: "
                      f"avg_length={m['avg_length']:.0f} chars, "
                      f"repetition_rate={m['repetition_rate']:.1%}, "
                      f"lexical_diversity={m['lexical_diversity']:.2f}")

            results.append(prompt_results)

        return results

    @staticmethod
    def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
        """Computes quality metrics for generated text.

        Metrics:
          - avg_length:        Average generation length (characters)
          - avg_word_count:    Average word count
          - repetition_rate:   n-gram repetition rate (lower is better)
          - lexical_diversity: Ratio of unique words (higher means more diverse)
          - sample_diversity:  Diversity across samples (how different are different generations)

        Returns an empty dict when ``texts`` is empty.
        """
        if not texts:
            return {}

        # Length statistics (characters and whitespace-split words).
        lengths = [len(t) for t in texts]
        word_counts = [len(t.split()) for t in texts]

        # Repetition rate, 4-gram based: 1 - (unique 4-grams / total 4-grams).
        rep_rates = []
        for text in texts:
            words = text.lower().split()
            if len(words) < 4:
                # Too short to form a single 4-gram; count as repetition-free.
                rep_rates.append(0.0)
                continue
            # len(words) >= 4 guarantees at least one n-gram, so no empty guard
            # is needed here.
            ngrams = [tuple(words[i:i + 4]) for i in range(len(words) - 3)]
            rep_rates.append(1.0 - len(set(ngrams)) / len(ngrams))

        # Lexical diversity (Type-Token Ratio) per sample.
        diversities = []
        for text in texts:
            words = text.lower().split()
            if words:
                diversities.append(len(set(words)) / len(words))
            else:
                diversities.append(0.0)

        # Inter-sample diversity: 1 - mean pairwise Jaccard similarity.
        sample_div = 0.0
        if len(texts) > 1:
            word_sets = [set(t.lower().split()) for t in texts]
            similarities = []
            for i in range(len(word_sets)):
                for j in range(i + 1, len(word_sets)):
                    union = len(word_sets[i] | word_sets[j])
                    if union > 0:
                        inter = len(word_sets[i] & word_sets[j])
                        similarities.append(inter / union)
                    else:
                        # Both samples are empty, hence identical: similarity 1.
                        # (Previously such pairs were skipped, which reported
                        # maximal diversity for identical empty generations.)
                        similarities.append(1.0)
            # len(texts) > 1 guarantees at least one pair, so the mean is safe.
            sample_div = 1.0 - sum(similarities) / len(similarities)

        return {
            "avg_length": sum(lengths) / len(lengths),
            "avg_word_count": sum(word_counts) / len(word_counts),
            "repetition_rate": sum(rep_rates) / len(rep_rates),
            "lexical_diversity": sum(diversities) / len(diversities),
            "sample_diversity": round(sample_div, 3),
        }