File size: 2,323 Bytes
53f0cc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Dataset for custom fine-tuning pairs (JSON or JSONL).
Expected fields: prompt, code, optional language.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List

from torch.utils.data import Dataset

from src.tokenizer.code_tokenizer import CodeTokenizer


class CustomPairDataset(Dataset):
    """Tokenized prompt/code pairs for custom fine-tuning.

    Each record must provide non-empty ``prompt`` and ``code`` fields;
    ``language`` is optional and falls back to ``"python"`` when missing,
    null, or blank. Records lacking a prompt or code are skipped. Each kept
    record is rendered with ``tokenizer.format_training_sample``, encoded
    with ``tokenizer.encode``, truncated to ``max_seq_len`` ids, and retained
    only if at least ``min_tokens`` ids remain.

    Args:
        path: Path to a ``.json`` or ``.jsonl`` file (suffix is
            case-insensitive). A ``.json`` file must hold a list of records or
            an object ``{"data": [...]}``; a ``.jsonl`` file holds one JSON
            object per line, blank lines ignored.
        tokenizer: Object providing ``format_training_sample`` and ``encode``.
        max_seq_len: Maximum number of token ids kept per sample; must be > 0.
        min_tokens: Samples shorter than this (after truncation) are dropped.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If ``max_seq_len`` is not positive, the suffix or JSON
            layout is unsupported, a record is not a JSON object, or no usable
            samples remain. Malformed JSON raises ``json.JSONDecodeError``
            (a ``ValueError`` subclass).
    """

    def __init__(
        self,
        path: str,
        tokenizer: CodeTokenizer,
        max_seq_len: int = 512,
        min_tokens: int = 8,
    ) -> None:
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"Custom fine-tune data file not found: {self.path}")
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.min_tokens = min_tokens
        self.rows: List[List[int]] = []
        self._load()

    def _read_records(self) -> List[object]:
        """Parse the source file into a list of raw JSON values, one per record."""
        suffix = self.path.suffix.lower()
        if suffix == ".jsonl":
            records: List[object] = []
            # Iterate the file object instead of str.splitlines(): splitlines()
            # also breaks on U+2028/U+2029/U+0085, which JSON allows unescaped
            # inside strings, and would cut a valid record in half.
            with self.path.open(encoding="utf-8-sig") as handle:
                for lineno, line in enumerate(handle, start=1):
                    # utf-8-sig only removes a leading BOM; concatenated JSONL
                    # files can carry further BOMs at the start of later lines.
                    line = line.strip().lstrip("\ufeff")
                    if not line:
                        continue
                    try:
                        records.append(json.loads(line))
                    except json.JSONDecodeError as err:
                        # Keep the exception type but say which line failed.
                        raise json.JSONDecodeError(
                            f"{self.path}:{lineno}: {err.msg}", err.doc, err.pos
                        ) from err
            return records
        if suffix == ".json":
            raw = json.loads(self.path.read_text(encoding="utf-8-sig"))
            data = raw.get("data") if isinstance(raw, dict) else raw
            if not isinstance(data, list):
                raise ValueError("JSON fine-tune file must be a list or {'data': [...]}.")
            return data
        raise ValueError("Custom fine-tune file must be .json or .jsonl")

    @staticmethod
    def _text_field(record: Dict[str, object], key: str) -> str:
        """Return ``record[key]`` as stripped text; a missing key or JSON null yields ""."""
        value = record.get(key)
        # str(None) would produce the bogus text "None", so treat null as absent.
        return "" if value is None else str(value).strip()

    def _load(self) -> None:
        """Read the source file and fill ``self.rows`` with token-id lists."""
        for index, record in enumerate(self._read_records()):
            if not isinstance(record, dict):
                raise ValueError(
                    f"Fine-tune record at index {index} in {self.path} must be a JSON object, "
                    f"got {type(record).__name__}."
                )
            prompt = self._text_field(record, "prompt")
            code = self._text_field(record, "code")
            language = self._text_field(record, "language").lower() or "python"
            if not prompt or not code:
                continue
            text = self.tokenizer.format_training_sample(prompt=prompt, code=code, language=language)
            ids = self.tokenizer.encode(text)[: self.max_seq_len]
            if len(ids) >= self.min_tokens:
                self.rows.append(ids)

        if not self.rows:
            raise ValueError("No valid samples found in custom fine-tune data.")

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> List[int]:
        return self.rows[idx]