"""
dataset.py — Cross-Script Translation Fix
==========================================
INPUT : quote_text (Roman/IAST transliteration of Sanskrit)
TARGET : quote_devanagari (Devanagari script)
This is the CORRECT task: the model learns to transliterate / translate
Roman Sanskrit → Devanagari, which is a meaningful, learnable mapping
(far better than Devanagari→Devanagari reconstruction, which teaches nothing).
KEY CHANGES from original:
1. _input_field = 'quote_text' (was 'quote_devanagari')
2. _target_field = 'quote_devanagari' (unchanged)
3. Separate source/target tokenizers — Roman and Devanagari have
completely different character sets; a shared BPE vocab forces the
model to learn both scripts in one embedding table, which wastes
capacity and confuses the attention mechanism.
4. Negative example generation fixed — reversal now operates on the
DEVANAGARI target only (not accidentally on the Roman source).
5. curriculum_sort uses the target (Devanagari) length as its difficulty proxy.
"""
from datasets import load_dataset
from torch.utils.data import Dataset
import torch
import torch.nn.functional as F
import random
class OptimizedSanskritDataset(Dataset):
    """(input, target) text pairs loaded from ``paws/sanskrit-verses-gretil``.

    Each item is a dict with:
      * ``input_ids`` / ``target_ids``: long tensors of length ``max_len``,
        truncated and right-padded with ``pad_id``;
      * ``input_text`` / ``target_text``: the stripped raw strings;
      * ``negative_target_ids`` (only when
        ``cfg['data']['include_negative_examples']`` is truthy): a copy of
        ``target_ids`` with one random chunk reversed.
    """

    def __init__(self, split='train', tokenizer=None, max_len=80, cfg=None,
                 src_tokenizer=None, tgt_tokenizer=None):
        """
        Args:
            split         : split name passed to ``load_dataset``; the 'train'
                            split is additionally curriculum-sorted.
            tokenizer     : shared tokenizer (legacy — used if src/tgt not provided)
            max_len       : length every encoded sequence is truncated/padded to.
            cfg           : config dict; falls back to ``config.CONFIG`` when falsy.
            src_tokenizer : tokenizer for the input field (Roman script)
            tgt_tokenizer : tokenizer for the target field (Devanagari script)
                            If None, falls back to shared `tokenizer`.

        Raises:
            ValueError: if no tokenizer is available for the source or the target.
        """
        if not cfg:
            # Only pull in the project config when the caller did not supply one.
            from config import CONFIG
            cfg = CONFIG
        self.cfg = cfg
        self.max_len = max_len
        # NOTE(review): pad id is hard-coded; confirm it matches both tokenizers.
        self.pad_id = 1
        self.mask_id = self.cfg['diffusion']['mask_token_id']
        self.include_negatives = self.cfg['data']['include_negative_examples']

        # ── Tokenizer setup ──────────────────────────────────────────
        # Support both legacy (shared) and new (separate src/tgt) tokenizers.
        self.src_tokenizer = src_tokenizer or tokenizer
        self.tgt_tokenizer = tgt_tokenizer or tokenizer
        # Both sides are encoded in __getitem__, so both must be resolvable now
        # rather than failing later with an AttributeError on None.
        if self.src_tokenizer is None or self.tgt_tokenizer is None:
            raise ValueError(
                "Provide a shared `tokenizer`, or both `src_tokenizer` and "
                "`tgt_tokenizer`."
            )

        print(f"📥 Loading '{split}' split …")
        raw = load_dataset("paws/sanskrit-verses-gretil", split=split)
        self._input_field, self._target_field = self._select_fields(raw.column_names)

        # ── Filter empty rows ────────────────────────────────────────
        # Some rows have empty (or missing) text — skip them. The field names
        # are bound to locals so the closure does not capture `self`, which
        # `datasets` would otherwise have to hash for its cache fingerprint.
        input_field, target_field = self._input_field, self._target_field
        raw = raw.filter(
            lambda ex: (
                bool((ex[input_field] or '').strip()) and
                bool((ex[target_field] or '').strip())
            )
        )
        print(f" After empty-filter: {len(raw)} samples")
        self.dataset = raw
        if split == 'train':
            self.dataset = self._curriculum_sort()
        print(f"✅ {len(self.dataset)} samples loaded.")

    # ── Field selection ──────────────────────────────────────────────
    @staticmethod
    def _select_fields(cols):
        """Return ``(input_field, target_field)`` for the given column names."""
        if 'quote_text' in cols and 'quote_devanagari' in cols:
            # CORRECT setup: Roman input → Devanagari output
            print(" Format: quote_text (Roman) → quote_devanagari (Devanagari) ✓")
            return 'quote_text', 'quote_devanagari'
        if 'sentence1' in cols and 'sentence2' in cols:
            # PAWS paraphrase pairs fallback
            print(" Format: PAWS sentence pairs ✓")
            return 'sentence1', 'sentence2'
        # Last resort: same field on both sides
        print(" ⚠️ Format: Devanagari→Devanagari (suboptimal — no quote_text found)")
        return 'quote_devanagari', 'quote_devanagari'

    # ── Encoding ─────────────────────────────────────────────────────
    def _encode_with(self, tokenizer, text):
        """Encode ``text``, truncate to ``max_len`` and right-pad with ``pad_id``."""
        ids = tokenizer.encode(text)[:self.max_len]
        t = torch.tensor(ids, dtype=torch.long)
        return F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)

    def _encode_src(self, text):
        """Encode source (Roman) text with ``src_tokenizer``."""
        return self._encode_with(self.src_tokenizer, text)

    def _encode_tgt(self, text):
        """Encode target (Devanagari) text with ``tgt_tokenizer``."""
        return self._encode_with(self.tgt_tokenizer, text)

    # ── Curriculum ───────────────────────────────────────────────────
    def _curriculum_sort(self):
        """Short, common Devanagari targets first → long, rare targets last.

        score = (#whitespace-separated words) * (1 - unique_chars / total_chars);
        samples are returned in ascending score order.

        NOTE(review): for equal word counts, a higher unique-character ratio
        gives a LOWER score (earlier position), which looks opposite to
        "rare targets last" — confirm the intended direction.
        """
        # Read the target column once instead of materialising every full row.
        targets = self.dataset[self._target_field]
        scores = []
        for text in targets:
            length = len(text.split())
            rarity_score = len(set(text)) / max(1, len(text))
            scores.append(length * (1 - rarity_score))
        order = sorted(range(len(scores)), key=scores.__getitem__)
        return self.dataset.select(order)

    # ── Negatives ────────────────────────────────────────────────────
    def _make_negative(self, target_ids):
        """Return a copy of ``target_ids`` with one random chunk reversed.

        The chunk lies within the first ``non_pad`` positions and spans at
        least two tokens, so the result differs from the target unless that
        chunk is a palindrome. Sequences with 4 or fewer non-pad tokens are
        returned as an unchanged copy.
        """
        neg_ids = target_ids.clone()
        # Assumes padding is only at the tail (as produced by _encode_with).
        non_pad = (neg_ids != self.pad_id).sum().item()
        if non_pad > 4:
            # Half-open slice [i1, i2) with i2 - i1 >= 2; i2 may equal non_pad
            # so the final real token can be part of the reversed chunk.
            i1 = random.randrange(non_pad - 1)
            i2 = random.randint(i1 + 2, non_pad)
            neg_ids[i1:i2] = torch.flip(neg_ids[i1:i2], dims=[0])
        return neg_ids

    # ── Item ─────────────────────────────────────────────────────────
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        src_text = sample[self._input_field].strip()
        tgt_text = sample[self._target_field].strip()
        target_ids = self._encode_tgt(tgt_text)  # Devanagari via tgt_tokenizer
        out = {
            'input_ids': self._encode_src(src_text),  # Roman via src_tokenizer
            'target_ids': target_ids,
            'input_text': src_text,
            'target_text': tgt_text,
        }
        if self.include_negatives:
            # Corrupt the DEVANAGARI target only, never the Roman source.
            out['negative_target_ids'] = self._make_negative(target_ids)
        return out