bbkdevops's picture
download
raw
2.4 kB
"""Dataset สำหรับ QA chatbot — format เป็น chat template"""
import json
import torch
from pathlib import Path
from torch.utils.data import Dataset
from tokenizers import Tokenizer
# Prompt layout for a single QA pair: a fixed Thai system message (roughly
# "You are TinyMind, a smart AI assistant that always answers correctly"),
# the user question, then the assistant answer terminated by <eos>.
# The {question} and {answer} placeholders are filled in with str.format().
CHAT_TEMPLATE = (
    "<bos><system>คุณคือ TinyMind ผู้ช่วย AI ที่ฉลาดและตอบถูกต้องเสมอ</system>\n"
    "<user>{question}</user>\n"
    "<assistant>{answer}<eos>"
)
class QADataset(Dataset):
def __init__(self, path: str | Path, tokenizer: Tokenizer, max_len: int = 1024):
self.tokenizer = tokenizer
self.max_len = max_len
self.samples: list[dict] = []
with open(path, encoding="utf-8") as f:
for line in f:
try:
item = json.loads(line.strip())
if item.get("question") and item.get("answer"):
self.samples.append(item)
except Exception:
pass
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
item = self.samples[idx]
text = CHAT_TEMPLATE.format(
question=item["question"],
answer=item["answer"],
)
enc = self.tokenizer.encode(text)
ids = enc.ids[:self.max_len]
input_ids = torch.tensor(ids, dtype=torch.long)
labels = input_ids.clone()
# Mask loss on prompt — เรียนรู้เฉพาะ answer
# หา position ของ <assistant> token
assistant_id = self.tokenizer.token_to_id("<assistant>")
if assistant_id is not None and assistant_id in ids:
sep_pos = ids.index(assistant_id) + 1
labels[:sep_pos] = -100
return {"input_ids": input_ids, "labels": labels}
def collate_fn(batch: list[dict], pad_id: int = 0) -> dict[str, torch.Tensor]:
    """Right-pad a list of samples into rectangular batch tensors.

    Args:
        batch: Samples, each a dict holding 1-D ``"input_ids"`` and
            ``"labels"`` tensors of equal length.
        pad_id: Token id written into the padded ``input_ids`` positions.

    Returns:
        A dict with ``"input_ids"`` (padded with ``pad_id``), ``"labels"``
        (padded with -100) and ``"attention_mask"`` (1 on real tokens, 0 on
        padding), each of shape ``(len(batch), longest_sample)`` and dtype
        ``torch.long``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    lengths = [x["input_ids"].shape[0] for x in batch]
    max_len = max(lengths)
    shape = (len(batch), max_len)
    input_ids = torch.full(shape, pad_id, dtype=torch.long)
    labels = torch.full(shape, -100, dtype=torch.long)
    # The mask is built from the true lengths rather than from
    # `input_ids != pad_id`, so a real token whose id happens to equal
    # pad_id is not mistaken for padding.
    attention_mask = torch.zeros(shape, dtype=torch.long)
    for i, (x, n) in enumerate(zip(batch, lengths)):
        input_ids[i, :n] = x["input_ids"]
        labels[i, :n] = x["labels"]
        attention_mask[i, :n] = 1
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

Xet Storage Details

Size:
2.4 kB
·
Xet hash:
12ec6ef7ba45a631790799f9fc545e66c1a95c140581b48379994fd792fff021

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.