File size: 7,278 Bytes
278d278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import torch
import os
import re
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm import tqdm
from .clean_turkish_data import get_clean_loader, CleanTurkishDataset

def prepare_dictionary_data(data_dir="./data"):
    """Download and cache the TDK Turkish dictionary corpus (Stage 1).

    Writes the corpus as raw UTF-8 bytes to ``stage1_dictionary.bin`` inside
    *data_dir* and returns its path. Returns the cached path immediately if it
    already exists. Returns ``None`` when the dataset cannot be downloaded or
    processed, signalling the caller to fall back to the clean Wikipedia data.
    """
    output_path = os.path.join(data_dir, "stage1_dictionary.bin")
    if os.path.exists(output_path):
        return output_path

    # Ensure the target directory exists before attempting to write into it.
    os.makedirs(data_dir, exist_ok=True)

    print("[Curriculum] Downloading Dictionary Dataset (Stage 1)...")

    try:
        print("[Curriculum] Trying 'erogluegemen/TDK_Turkish_Words' (word meanings only)...")
        # Pin the exact CSV so schema differences between files in the
        # repository can't cause a column mismatch.
        dataset = load_dataset(
            "erogluegemen/TDK_Turkish_Words",
            data_files="tdk_word_meaning_data.csv",
            split="train",
        )

        collected_bytes = []
        print("[Curriculum] Processing Dictionary...")
        for item in tqdm(dataset):
            # This CSV has: 'madde' (word), 'anlam' (meaning)
            word = str(item.get('madde', '')).strip()
            meaning = str(item.get('anlam', '')).strip()

            # Skip rows where either field is empty after stripping.
            if word and meaning:
                text = f"{word}: {meaning}.\n\n"
                collected_bytes.append(text.encode('utf-8'))

        if not collected_bytes:
            raise ValueError("No valid entries found in dataset")

        full_data = b"".join(collected_bytes)
        with open(output_path, "wb") as f:
            f.write(full_data)

        print(f"[Curriculum] Stage 1 Data Ready: {len(full_data)/1e6:.1f}MB")
        return output_path

    except Exception as e:
        print(f"⚠️ Dictionary dataset failed: {e}")
        print("Fallback: Using clean Wikipedia data for Stage 1")
        return None

def prepare_stories_data(data_dir="./data"):
    """Download and cache a children's-stories corpus (Stage 2).

    Writes the corpus as raw UTF-8 bytes to ``stage2_stories.bin`` inside
    *data_dir* and returns its path (cached path returned immediately when
    present). If the primary dataset is unavailable, falls back to the first
    20MB of the clean Wikipedia corpus; returns ``None`` only when the
    fallback also fails.
    """
    output_path = os.path.join(data_dir, "stage2_stories.bin")
    if os.path.exists(output_path):
        return output_path

    # Ensure the target directory exists before attempting to write into it.
    os.makedirs(data_dir, exist_ok=True)

    print("[Curriculum] Downloading Children Stories Dataset (Stage 2)...")
    try:
        # Primary source named in the RFC; it may not exist on the Hub, in
        # which case the except-branch below builds a Wikipedia-based stand-in.
        dataset = load_dataset("turkish-children-stories", split="train")

        collected_bytes = []
        print("[Curriculum] Processing Stories...")
        for item in tqdm(dataset):
            # 'text' may be missing OR present-but-None; `or ''` guards both
            # (None.strip() would raise AttributeError).
            text = str(item.get('text') or '').strip()
            if text:
                collected_bytes.append(text.encode('utf-8'))
                collected_bytes.append(b'\n\n')

        full_data = b"".join(collected_bytes)
        with open(output_path, "wb") as f:
            f.write(full_data)

        print(f"[Curriculum] Stage 2 Data Ready: {len(full_data)/1e6:.1f}MB")
        return output_path

    except Exception as e:
        print(f"⚠️ Failed to load stories dataset: {e}")
        print("Fallback: Creating synthetic simple dataset from Wikipedia (Stage 2)")

        # Fallback: take a prefix of the clean Wikipedia corpus as a stand-in
        # for "simple" text. Real sentence-level filtering is left for later.
        try:
            wiki_path = os.path.join(data_dir, "trwiki_clean_train.bin")
            if not os.path.exists(wiki_path):
                from .clean_turkish_data import prepare_clean_turkish_data
                prepare_clean_turkish_data(data_dir)

            # Read only the first 20MB instead of loading the full corpus
            # (~150MB) into memory and slicing afterwards — avoids needless
            # peak-memory use for this best-effort fallback.
            limit = 20 * 1024 * 1024
            with open(wiki_path, "rb") as f:
                simple_data = f.read(limit)

            with open(output_path, "wb") as f:
                f.write(simple_data)

            return output_path
        except Exception as e2:
            print(f"Fallback failed: {e2}")
            return None

class CurriculumDataLoader:
    """
    Manages the data curriculum for AGIFORMER Phase 7.

    Switches data sources (and the plasticity coefficient) by training
    progress (step / max_steps):

      Stage 1 (< 15%): dictionary corpus   — Lexical Grounding
      Stage 2 (< 40%): children's stories  — Syntactic Scaffolding
      Stage 3 (rest):  clean Wikipedia     — Semantic Expansion
    """
    def __init__(self, data_dir, batch_size, seq_len, max_steps):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.max_steps = max_steps
        # 0 means "no stage entered yet": the first check_stage_change()
        # call always reports a transition (into Stage 1).
        self.current_stage = 0
        # Per-stage DataLoader cache; entries are built lazily on first use.
        self.loaders = {}

    def _get_stage(self, step):
        """Map a training step to its curriculum stage (1, 2, or 3)."""
        progress = step / self.max_steps
        if progress < 0.15:
            return 1  # Lexical Grounding
        if progress < 0.40:
            return 2  # Syntactic Scaffolding
        return 3  # Semantic Expansion

    def get_loader(self, step):
        """Return the DataLoader active at `step`, building it lazily."""
        stage = self._get_stage(step)
        # Loaders are cached per stage, so only the first request for a
        # given stage pays the construction/download cost.
        if stage not in self.loaders:
            self.loaders[stage] = self._create_loader_for_stage(stage)
        return self.loaders[stage]

    def _create_loader_for_stage(self, stage):
        """Build the DataLoader for `stage`.

        Stages 1 and 2 fall back to the clean Wikipedia loader when their
        stage-specific corpus could not be prepared.
        """
        if stage == 1:
            print("\n[Curriculum] Initializing Stage 1: Lexical Grounding (Dictionary)")
            path = prepare_dictionary_data(self.data_dir)
            if path:
                dataset = CleanTurkishDataset(path, self.seq_len)
                return DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                  num_workers=0, pin_memory=True)
            return get_clean_loader(self.data_dir, self.batch_size, self.seq_len, split="train")

        if stage == 2:
            print("\n[Curriculum] Initializing Stage 2: Syntactic Scaffolding (Children Stories)")
            path = prepare_stories_data(self.data_dir)
            if path:
                dataset = CleanTurkishDataset(path, self.seq_len)
                return DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                  num_workers=0, pin_memory=True)
            return get_clean_loader(self.data_dir, self.batch_size, self.seq_len, split="train")

        print("\n[Curriculum] Initializing Stage 3: Semantic Expansion (Wikipedia)")
        return get_clean_loader(self.data_dir, self.batch_size, self.seq_len, split="train")

    def check_stage_change(self, step):
        """Return True (and record the new stage) if the stage changed at `step`."""
        new_stage = self._get_stage(step)
        if new_stage != self.current_stage:
            print(f"\n*** CURRICULUM ALERT: Advancing to Stage {new_stage} ***")
            self.current_stage = new_stage
            return True
        return False

    def get_plasticity_alpha(self, step):
        """
        Returns the plasticity coefficient (alpha) based on the schedule.

        Stage 1 (Childhood): 0.1 (High plasticity, fast forgetting)
        Stage 2 (Youth): 0.5 (Balanced)
        Stage 3 (Adulthood): 0.99 (Low plasticity, stable memory)
        """
        stage = self._get_stage(step)
        if stage == 1:
            return 0.1
        if stage == 2:
            return 0.5
        return 0.99