| """ |
| Dataset for custom fine-tuning pairs (JSON or JSONL). |
| Expected fields: prompt, code, optional language. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| from torch.utils.data import Dataset |
|
|
| from src.tokenizer.code_tokenizer import CodeTokenizer |
|
|
|
|
class CustomPairDataset(Dataset):
    """Tokenized prompt/code pairs loaded from a ``.json`` or ``.jsonl`` file.

    Each record is expected to be a JSON object with ``prompt`` and ``code``
    fields and an optional ``language`` field (defaults to ``"python"``).
    A ``.json`` file holds either a list of records or an object of the form
    ``{"data": [...]}``; a ``.jsonl`` file holds one record per line.

    Every usable record is rendered with
    ``tokenizer.format_training_sample(prompt=..., code=..., language=...)``,
    encoded with ``tokenizer.encode(text)`` and truncated to ``max_seq_len``
    ids. Records that are not objects, that lack a prompt or code, or whose
    truncated encoding is shorter than ``_MIN_TOKENS`` ids are skipped.

    Args:
        path: Location of the ``.json`` / ``.jsonl`` data file.
        tokenizer: Object providing ``format_training_sample`` and ``encode``
            as used above.
        max_seq_len: Maximum number of ids kept per sample.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file extension or JSON layout is unsupported, or
            if no usable sample remains after filtering.
        json.JSONDecodeError: If the file (or a JSONL line) is not valid JSON.
    """

    # Samples whose truncated encoding has fewer ids than this are dropped.
    _MIN_TOKENS = 8

    def __init__(self, path: str, tokenizer: "CodeTokenizer", max_seq_len: int = 512) -> None:
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"Custom fine-tune data file not found: {self.path}")
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.rows: List[List[int]] = []
        self._load()

    @staticmethod
    def _clean(row: Dict[str, object], key: str, default: str = "") -> str:
        """Return ``row[key]`` as stripped text; missing or null gives ``default``."""
        value = row.get(key)
        # A JSON null must not be stringified to the literal text "None".
        return default if value is None else str(value).strip()

    def _read_records(self) -> List[object]:
        """Parse the data file and return its raw list of records."""
        suffix = self.path.suffix.lower()

        if suffix == ".jsonl":
            text = self.path.read_text(encoding="utf-8-sig")
            records: List[object] = []
            # Split on "\n" only: str.splitlines() also breaks on U+2028,
            # U+2029 and U+0085, which may legally appear unescaped inside a
            # JSON string and would cut a record in half.
            for line_no, line in enumerate(text.split("\n"), start=1):
                line = line.strip().lstrip("\ufeff")
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # Same exception type, but tell the user which line is bad.
                    raise json.JSONDecodeError(
                        f"{err.msg} (record on line {line_no} of {self.path})",
                        err.doc,
                        err.pos,
                    ) from err
            return records

        if suffix == ".json":
            raw = json.loads(self.path.read_text(encoding="utf-8-sig"))
            if isinstance(raw, dict) and "data" in raw:
                raw = raw["data"]
            if not isinstance(raw, list):
                raise ValueError("JSON fine-tune file must be a list or {'data': [...]}.")
            return raw

        raise ValueError("Custom fine-tune file must be .json or .jsonl")

    def _load(self) -> None:
        """Read the data file and fill ``self.rows`` with encoded samples."""
        for row in self._read_records():
            if not isinstance(row, dict):
                # Not a JSON object: skip it like any other unusable record.
                continue
            prompt = self._clean(row, "prompt")
            code = self._clean(row, "code")
            if not prompt or not code:
                continue
            language = self._clean(row, "language", "python").lower() or "python"
            text = self.tokenizer.format_training_sample(prompt=prompt, code=code, language=language)
            ids = self.tokenizer.encode(text)[: self.max_seq_len]
            if len(ids) >= self._MIN_TOKENS:
                self.rows.append(ids)

        if not self.rows:
            raise ValueError("No valid samples found in custom fine-tune data.")

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.rows)

    def __getitem__(self, idx: int) -> List[int]:
        """Return the encoded ids of the sample at position ``idx``."""
        return self.rows[idx]
|
|
|
|