| """ |
| dataset.py — Cross-Script Translation Fix |
| ========================================== |
| INPUT : quote_text (Roman/IAST transliteration of Sanskrit) |
| TARGET : quote_devanagari (Devanagari script) |
| |
| This is the CORRECT task: the model learns to transliterate / translate |
| Roman Sanskrit → Devanagari, which is a meaningful, learnable mapping |
| (far better than Devanagari→Devanagari reconstruction, which teaches nothing). |
| |
| KEY CHANGES from original: |
| 1. _input_field = 'quote_text' (was 'quote_devanagari') |
| 2. _target_field = 'quote_devanagari' (unchanged) |
| 3. Separate source/target tokenizers — Roman and Devanagari have |
| completely different character sets; a shared BPE vocab forces the |
| model to learn both scripts in one embedding table, which wastes |
| capacity and confuses the attention mechanism. |
| 4. Negative example generation fixed — reversal now operates on |
| DEVANAGARI target only (not accidentally on Roman source). |
| 5. curriculum_sort uses target length (Devanagari) for difficulty proxy. |
| """ |
|
|
| from datasets import load_dataset |
| from torch.utils.data import Dataset |
| import torch |
| import torch.nn.functional as F |
| import random |
|
|
|
|
class OptimizedSanskritDataset(Dataset):
    """Paired-text dataset yielding fixed-length source/target token tensors.

    Loads the ``paws/sanskrit-verses-gretil`` dataset, picks an
    (input column, target column) pair based on the columns present, drops
    rows where either side is empty, and (for the ``train`` split) orders the
    rows with :meth:`_curriculum_sort`.

    Each item is a dict with ``input_ids`` / ``target_ids`` (``torch.long``
    tensors of length ``max_len``, right-padded with ``pad_id``), the stripped
    ``input_text`` / ``target_text`` strings and, when enabled in the config,
    ``negative_target_ids``.
    """

    # Upper bound on attempts to find a span whose reversal actually changes
    # the sequence (a palindromic span reverses to itself).
    _NEG_MAX_TRIES = 8

    def __init__(self, split='train', tokenizer=None, max_len=80, cfg=None,
                 src_tokenizer=None, tgt_tokenizer=None):
        """
        Args:
            split         : dataset split name passed to ``load_dataset``.
            tokenizer     : shared tokenizer (legacy — used for whichever of
                            src/tgt is not provided).
            max_len       : fixed length every encoded sequence is
                            truncated / padded to.
            cfg           : config mapping; must provide
                            ``cfg['diffusion']['mask_token_id']`` and
                            ``cfg['data']['include_negative_examples']``.
                            When falsy, ``config.CONFIG`` is used.
            src_tokenizer : tokenizer for the input column.
            tgt_tokenizer : tokenizer for the target column.

        Tokenizers only need an ``encode(text)`` method returning a sliceable
        sequence of ints.

        Raises:
            ValueError: if a source or target tokenizer cannot be resolved,
                or if the dataset has none of the supported column layouts.
        """
        # Resolve tokenizers first so a bad call fails before any I/O.
        # `is not None` (rather than `or`) so a tokenizer object that happens
        # to be falsy (e.g. defines __len__) is still honoured.
        self.src_tokenizer = src_tokenizer if src_tokenizer is not None else tokenizer
        self.tgt_tokenizer = tgt_tokenizer if tgt_tokenizer is not None else tokenizer

        missing = [
            name
            for name, tok in (('src_tokenizer', self.src_tokenizer),
                              ('tgt_tokenizer', self.tgt_tokenizer))
            if tok is None
        ]
        if missing:
            raise ValueError(
                "Provide at least one tokenizer. Missing: "
                + ", ".join(missing)
                + " (pass `tokenizer`, or both `src_tokenizer` and `tgt_tokenizer`)."
            )

        # Only touch the project config module when the caller gave no cfg.
        if not cfg:
            from config import CONFIG
            cfg = CONFIG
        self.cfg = cfg

        self.max_len = max_len
        # NOTE(review): pad id is hard-coded; assumed to match the pad id of
        # both tokenizers — confirm against the tokenizer setup.
        self.pad_id = 1
        # Stored for consumers of the dataset; not used inside this class.
        self.mask_id = self.cfg['diffusion']['mask_token_id']
        self.include_negatives = self.cfg['data']['include_negative_examples']

        print(f"📥 Loading '{split}' split …")
        raw = load_dataset("paws/sanskrit-verses-gretil", split=split)
        cols = raw.column_names

        if 'quote_text' in cols and 'quote_devanagari' in cols:
            # Preferred layout: Roman transliteration in, Devanagari out.
            self._input_field = 'quote_text'
            self._target_field = 'quote_devanagari'
            print(" Format: quote_text (Roman) → quote_devanagari (Devanagari) ✓")
        elif 'sentence1' in cols and 'sentence2' in cols:
            self._input_field = 'sentence1'
            self._target_field = 'sentence2'
            print(" Format: PAWS sentence pairs ✓")
        elif 'quote_devanagari' in cols:
            # Last resort: identity task on the Devanagari column.
            self._input_field = 'quote_devanagari'
            self._target_field = 'quote_devanagari'
            print(" ⚠️ Format: Devanagari→Devanagari (suboptimal — no quote_text found)")
        else:
            # Fail here with the column list instead of a KeyError deep
            # inside the filter below.
            raise ValueError(
                f"Unsupported dataset schema; columns found: {list(cols)}"
            )

        # Bind the field names to locals so the predicate does not close over
        # `self`. `or ""` treats a missing (None) value as empty instead of
        # crashing on `.strip()`.
        in_field = self._input_field
        tgt_field = self._target_field
        raw = raw.filter(
            lambda ex: (
                bool((ex[in_field] or "").strip()) and
                bool((ex[tgt_field] or "").strip())
            )
        )
        print(f" After empty-filter: {len(raw)} samples")

        self.dataset = raw

        if split == 'train':
            self.dataset = self._curriculum_sort()

        print(f"✅ {len(self.dataset)} samples loaded.")

    def _encode(self, tokenizer, text):
        """Tokenize `text`, truncate to `max_len`, right-pad with `pad_id`."""
        ids = tokenizer.encode(text)[:self.max_len]
        t = torch.tensor(ids, dtype=torch.long)
        return F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)

    def _encode_src(self, text):
        """Encode source text with the source tokenizer (length `max_len`)."""
        return self._encode(self.src_tokenizer, text)

    def _encode_tgt(self, text):
        """Encode target text with the target tokenizer (length `max_len`)."""
        return self._encode(self.tgt_tokenizer, text)

    def _curriculum_sort(self):
        """Return the dataset reordered by ascending difficulty score.

        For each target text the score is
        ``word_count * (1 - unique_chars / total_chars)``; rows are sorted by
        that score, lowest first (ties keep their original order).

        NOTE(review): a higher unique-character ratio *lowers* the score, so
        character-diverse targets come earlier. This looks opposite to a
        "rare targets last" curriculum — confirm the intended formula.
        """
        field = self._target_field
        scores = []
        for sample in self.dataset:
            text = sample[field]
            n_words = len(text.split())
            unique_ratio = len(set(text)) / max(1, len(text))
            scores.append(n_words * (1 - unique_ratio))
        order = sorted(range(len(scores)), key=scores.__getitem__)
        return self.dataset.select(order)

    def _make_negative(self, target_ids):
        """Return a copy of `target_ids` with one random token span reversed.

        Sequences with four or fewer non-pad tokens are returned unchanged
        (as a copy). Otherwise a span of at least two tokens inside the
        non-pad prefix is reversed, so the result differs from the input
        unless no sampled span changes under reversal within
        `_NEG_MAX_TRIES` attempts.
        """
        neg_ids = target_ids.clone()
        # Counts non-pad tokens; the span is taken from the first `non_pad`
        # positions, which assumes padding is trailing (as `_encode` produces).
        non_pad = int((neg_ids != self.pad_id).sum().item())
        if non_pad > 4:
            for _ in range(self._NEG_MAX_TRIES):
                start = random.randrange(non_pad - 1)
                # end is exclusive; start + 2 guarantees a span of >= 2 tokens.
                end = random.randrange(start + 2, non_pad + 1)
                span = target_ids[start:end]
                flipped = torch.flip(span, dims=[0])
                if not torch.equal(flipped, span):
                    neg_ids[start:end] = flipped
                    break
        return neg_ids

    def __len__(self):
        """Number of samples after filtering."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the encoded sample at `idx`.

        Returns:
            dict with ``input_ids``, ``target_ids``, ``input_text``,
            ``target_text`` and, if ``include_negatives`` is set,
            ``negative_target_ids`` (target ids with a reversed span).
        """
        sample = self.dataset[idx]

        src_text = sample[self._input_field].strip()
        tgt_text = sample[self._target_field].strip()

        input_ids = self._encode_src(src_text)
        target_ids = self._encode_tgt(tgt_text)

        out = {
            'input_ids': input_ids,
            'target_ids': target_ids,
            'input_text': src_text,
            'target_text': tgt_text,
        }

        if self.include_negatives:
            # Corruption is applied to the target ids only, never the source.
            out['negative_target_ids'] = self._make_negative(target_ids)

        return out