Buckets:
| """Dataset สำหรับ QA chatbot — format เป็น chat template""" | |
| import json | |
| import torch | |
| from pathlib import Path | |
| from torch.utils.data import Dataset | |
| from tokenizers import Tokenizer | |
# Chat-format template that renders one QA pair as a single training string.
# ``{question}`` and ``{answer}`` are filled in with ``str.format``.
# The Thai system prompt reads: "You are TinyMind, an AI assistant that is
# smart and always answers correctly."
# The ``<assistant>`` tag marks where the answer begins; QADataset looks up
# this token to decide which label positions to mask.
CHAT_TEMPLATE = (
    "<bos>"
    "<system>คุณคือ TinyMind ผู้ช่วย AI ที่ฉลาดและตอบถูกต้องเสมอ</system>\n"
    "<user>{question}</user>\n"
    "<assistant>{answer}<eos>"
)
class QADataset(Dataset):
    """Question/answer pairs read from a JSONL file and rendered with ``CHAT_TEMPLATE``.

    Each line of the file is parsed as JSON; only JSON objects with truthy
    ``"question"`` and ``"answer"`` fields are kept. ``__getitem__`` returns
    token ids together with labels in which the prompt part is set to -100
    (presumably the loss's ignore index), so only the answer is learned.
    """

    def __init__(self, path: str | Path, tokenizer: Tokenizer, max_len: int = 1024):
        """Load and filter samples from ``path``.

        Args:
            path: JSONL file, one JSON object per line, read as UTF-8.
            tokenizer: tokenizer used to encode the rendered chat text.
            max_len: maximum number of token ids kept per sample; longer
                encodings are truncated from the end.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.samples: list[dict] = []
        # Non-blank lines dropped because they were not valid JSON, not a JSON
        # object, or lacked a question/answer. Exposed so callers can tell how
        # much of the file was usable instead of losing data silently.
        self.num_skipped = 0
        with open(path, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue  # blank lines are expected in JSONL; not an error
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    self.num_skipped += 1
                    continue
                # Valid JSON may still be a list/str/number; only dicts have .get.
                if isinstance(item, dict) and item.get("question") and item.get("answer"):
                    self.samples.append(item)
                else:
                    self.num_skipped += 1

    def __len__(self) -> int:
        """Return the number of usable samples loaded from the file."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Encode sample ``idx`` and build its loss labels.

        Returns:
            ``{"input_ids": LongTensor, "labels": LongTensor}``, both 1-D and of
            length ``<= max_len``. Label positions up to and including the
            ``<assistant>`` token are -100 (prompt is not learned).
        """
        item = self.samples[idx]
        text = CHAT_TEMPLATE.format(
            question=item["question"],
            answer=item["answer"],
        )
        full_ids = self.tokenizer.encode(text).ids
        ids = full_ids[:self.max_len]
        input_ids = torch.tensor(ids, dtype=torch.long)
        labels = input_ids.clone()

        # Mask the loss on the prompt so only the answer is learned: locate the
        # <assistant> token that separates prompt from answer.
        assistant_id = self.tokenizer.token_to_id("<assistant>")
        if assistant_id is not None:
            if assistant_id in ids:
                # NOTE(review): this uses the FIRST occurrence. If the question
                # text itself contains "<assistant>" and the tokenizer maps it to
                # this id, part of the prompt stays unmasked — confirm inputs.
                sep_pos = ids.index(assistant_id) + 1
                labels[:sep_pos] = -100
            elif assistant_id in full_ids:
                # Truncation cut off the marker, so every kept token belongs to
                # the prompt. Without this, the sample would train the model to
                # reproduce the system/user prompt.
                labels[:] = -100
        return {"input_ids": input_ids, "labels": labels}
def collate_fn(batch: list[dict], pad_id: int = 0) -> dict[str, torch.Tensor]:
    """Right-pad a list of samples into batch tensors.

    Args:
        batch: samples, each a dict with 1-D ``"input_ids"`` and ``"labels"``
            tensors of the same length.
        pad_id: token id written into padded ``input_ids`` positions.

    Returns:
        ``"input_ids"`` (padded with ``pad_id``), ``"labels"`` (padded with
        -100) and ``"attention_mask"`` (1 for real tokens, 0 for padding), each
        a LongTensor of shape ``(len(batch), longest_sample)``. An empty
        ``batch`` yields tensors of shape ``(0, 0)``.
    """
    batch_size = len(batch)
    max_len = max((x["input_ids"].shape[0] for x in batch), default=0)
    input_ids = torch.full((batch_size, max_len), pad_id, dtype=torch.long)
    labels = torch.full((batch_size, max_len), -100, dtype=torch.long)
    attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
    for i, x in enumerate(batch):
        n = x["input_ids"].shape[0]
        input_ids[i, :n] = x["input_ids"]
        labels[i, :n] = x["labels"]
        # Derive the mask from the true length, not from `input_ids != pad_id`:
        # a real token whose id happens to equal pad_id (the default 0 may be
        # <bos> or another real token) must still be attended to.
        attention_mask[i, :n] = 1
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
Xet Storage Details
- Size:
- 2.4 kB
- Xet hash:
- 12ec6ef7ba45a631790799f9fc545e66c1a95c140581b48379994fd792fff021
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.