|
|
import torch |
|
|
import os |
|
|
import re |
|
|
from torch.utils.data import DataLoader |
|
|
from datasets import load_dataset |
|
|
from tqdm import tqdm |
|
|
from .clean_turkish_data import get_clean_loader, CleanTurkishDataset |
|
|
|
|
|
def prepare_dictionary_data(data_dir="./data"):
    """Download and cache the TDK Turkish dictionary corpus (Stage 1).

    Builds a single binary file of UTF-8 ``"word: meaning.\\n\\n"`` entries at
    ``<data_dir>/stage1_dictionary.bin``. If the file already exists it is
    reused without re-downloading.

    Args:
        data_dir: Directory where the prepared .bin file is stored
            (created if missing).

    Returns:
        Path to the prepared .bin file, or ``None`` if the dataset could not
        be downloaded/processed (callers then fall back to Wikipedia data).
    """
    output_path = os.path.join(data_dir, "stage1_dictionary.bin")
    if os.path.exists(output_path):
        return output_path

    # Ensure the target directory exists before any download/write work;
    # the original code crashed here on a fresh checkout with no ./data dir.
    os.makedirs(data_dir, exist_ok=True)

    print("[Curriculum] Downloading Dictionary Dataset (Stage 1)...")

    try:
        print("[Curriculum] Trying 'erogluegemen/TDK_Turkish_Words' (word meanings only)...")
        dataset = load_dataset(
            "erogluegemen/TDK_Turkish_Words",
            data_files="tdk_word_meaning_data.csv",
            split="train"
        )

        collected_bytes = []
        print("[Curriculum] Processing Dictionary...")
        for item in tqdm(dataset):
            # Column names are Turkish: 'madde' = headword, 'anlam' = meaning.
            word = str(item.get('madde', '')).strip()
            meaning = str(item.get('anlam', '')).strip()

            # Truthiness already rejects empty strings; the original's extra
            # len(...) > 0 checks were redundant.
            if word and meaning:
                text = f"{word}: {meaning}.\n\n"
                collected_bytes.append(text.encode('utf-8'))

        if not collected_bytes:
            raise ValueError("No valid entries found in dataset")

        full_data = b"".join(collected_bytes)
        with open(output_path, "wb") as f:
            f.write(full_data)

        print(f"[Curriculum] Stage 1 Data Ready: {len(full_data)/1e6:.1f}MB")
        return output_path

    except Exception as e:
        # Deliberate best-effort: any failure (network, schema change, empty
        # data) is reported and the caller falls back to the Wikipedia loader.
        print(f"⚠️ Dictionary dataset failed: {e}")
        print("Fallback: Using clean Wikipedia data for Stage 1")
        return None
|
|
|
|
|
def prepare_stories_data(data_dir="./data"):
    """Download and cache a Turkish children's stories corpus (Stage 2).

    Writes plain UTF-8 story texts separated by blank lines to
    ``<data_dir>/stage2_stories.bin`` and reuses the cached file when present.
    If the stories dataset is unavailable, falls back to the first 20MB of
    the clean Wikipedia training corpus as a stand-in "simple" dataset.

    Args:
        data_dir: Directory where the prepared .bin file is stored
            (created if missing).

    Returns:
        Path to the prepared .bin file, or ``None`` if both the download and
        the Wikipedia fallback failed.
    """
    output_path = os.path.join(data_dir, "stage2_stories.bin")
    if os.path.exists(output_path):
        return output_path

    # Ensure the target directory exists before writing anything into it.
    os.makedirs(data_dir, exist_ok=True)

    print("[Curriculum] Downloading Children Stories Dataset (Stage 2)...")
    try:
        # NOTE(review): this dataset id has no "user/name" namespace, which
        # the Hugging Face Hub normally requires — confirm the intended repo.
        dataset = load_dataset("turkish-children-stories", split="train")

        collected_bytes = []
        print("[Curriculum] Processing Stories...")
        for item in tqdm(dataset):
            text = item.get('text', '').strip()
            if text:
                collected_bytes.append(text.encode('utf-8'))
                collected_bytes.append(b'\n\n')

        # Mirror the Stage 1 guard: an empty download should trigger the
        # fallback rather than silently caching an empty corpus file.
        if not collected_bytes:
            raise ValueError("No valid entries found in dataset")

        full_data = b"".join(collected_bytes)
        with open(output_path, "wb") as f:
            f.write(full_data)

        print(f"[Curriculum] Stage 2 Data Ready: {len(full_data)/1e6:.1f}MB")
        return output_path

    except Exception as e:
        print(f"⚠️ Failed to load stories dataset: {e}")
        print("Fallback: Creating synthetic simple dataset from Wikipedia (Stage 2)")

        try:
            wiki_path = os.path.join(data_dir, "trwiki_clean_train.bin")
            if not os.path.exists(wiki_path):
                # Lazily build the clean Wikipedia corpus if it is missing.
                from .clean_turkish_data import prepare_clean_turkish_data
                prepare_clean_turkish_data(data_dir)

            with open(wiki_path, "rb") as f:
                wiki_data = f.read()

            # Take only the first 20MB as the "simpler" Stage 2 corpus.
            limit = 20 * 1024 * 1024
            simple_data = wiki_data[:limit]

            with open(output_path, "wb") as f:
                f.write(simple_data)

            return output_path
        except Exception as e2:
            print(f"Fallback failed: {e2}")
            return None
|
|
|
|
|
class CurriculumDataLoader:
    """
    Manages the data curriculum for AGIFORMER Phase 7.

    Switches between data sources based on training progress:
      Stage 1 (progress < 15%): dictionary entries (lexical grounding)
      Stage 2 (progress < 40%): children stories (syntactic scaffolding)
      Stage 3 (remainder):      clean Wikipedia (semantic expansion)
    """
    def __init__(self, data_dir, batch_size, seq_len, max_steps):
        # data_dir: where prepared .bin corpora are cached.
        # max_steps: total training steps; must be > 0 (used as divisor).
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.max_steps = max_steps
        # 0 means "not started": the first check_stage_change() call always
        # reports a stage change, so callers initialize on step 0.
        self.current_stage = 0
        # Lazily-built cache: stage number -> DataLoader.
        self.loaders = {}

    def _get_stage(self, step):
        """Map a training step to curriculum stage 1, 2 or 3."""
        progress = step / self.max_steps
        if progress < 0.15:
            return 1
        elif progress < 0.40:
            return 2
        else:
            return 3

    def get_loader(self, step):
        """Return the DataLoader for `step`'s stage, building it on first use."""
        stage = self._get_stage(step)
        if stage not in self.loaders:
            self.loaders[stage] = self._create_loader_for_stage(stage)
        return self.loaders[stage]

    def _loader_from_bin(self, path):
        """Wrap a prepared .bin corpus in a shuffled DataLoader; fall back to
        the clean Wikipedia loader when preparation returned None."""
        if path:
            dataset = CleanTurkishDataset(path, self.seq_len)
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                              num_workers=0, pin_memory=True)
        return get_clean_loader(self.data_dir, self.batch_size, self.seq_len, split="train")

    def _create_loader_for_stage(self, stage):
        """Prepare the data source for `stage` and return its DataLoader.

        Stages 1 and 2 share the .bin-with-fallback construction via
        _loader_from_bin (previously duplicated inline).
        """
        if stage == 1:
            print("\n[Curriculum] Initializing Stage 1: Lexical Grounding (Dictionary)")
            return self._loader_from_bin(prepare_dictionary_data(self.data_dir))
        elif stage == 2:
            print("\n[Curriculum] Initializing Stage 2: Syntactic Scaffolding (Children Stories)")
            return self._loader_from_bin(prepare_stories_data(self.data_dir))
        elif stage == 3:
            print("\n[Curriculum] Initializing Stage 3: Semantic Expansion (Wikipedia)")
            return get_clean_loader(self.data_dir, self.batch_size, self.seq_len, split="train")

    def check_stage_change(self, step):
        """Returns True if the stage has changed at this step.

        Also fires on the very first call, since current_stage starts at 0.
        """
        new_stage = self._get_stage(step)
        if new_stage != self.current_stage:
            print(f"\n*** CURRICULUM ALERT: Advancing to Stage {new_stage} ***")
            self.current_stage = new_stage
            return True
        return False

    def get_plasticity_alpha(self, step):
        """
        Returns the plasticity coefficient (alpha) based on the schedule.

        Stage 1 (Childhood): 0.1 (High plasticity, fast forgetting)
        Stage 2 (Youth): 0.5 (Balanced)
        Stage 3 (Adulthood): 0.99 (Low plasticity, stable memory)
        """
        stage = self._get_stage(step)
        if stage == 1:
            return 0.1
        elif stage == 2:
            return 0.5
        else:
            return 0.99
|
|
|