mini-llava-demo / src /dataset.py
AD-Styles's picture
Upload folder using huggingface_hub
95e4119 verified
"""VQA-style dataset + collator.
manifest.json 형식:
[
{"image": "path/to/img.jpg", "question": "...", "answer": "..."},
...
]
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import List
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor, PreTrainedTokenizerBase
from .config import IGNORE_INDEX, IMAGE_TOKEN, SYSTEM_PROMPT
def _build_messages(question: str, answer: str | None = None):
"""Qwen2.5 chat template 형식의 messages list.
user 메시지에 <image>\\n 을 prepend 하여 이미지 위치를 명시한다.
"""
user_content = f"{IMAGE_TOKEN}\n{question}"
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
]
if answer is not None:
messages.append({"role": "assistant", "content": answer})
return messages
def encode_for_training(
tokenizer: PreTrainedTokenizerBase,
question: str,
answer: str,
max_length: int = 256,
):
"""학습용: full conversation + instruction-only label masking.
답변(assistant) 토큰만 loss를 받고, 그 이전(system+user)은 IGNORE_INDEX 처리.
"""
full_msgs = _build_messages(question, answer)
prompt_msgs = _build_messages(question, answer=None)
full_text = tokenizer.apply_chat_template(
full_msgs, tokenize=False, add_generation_prompt=False
)
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
full = tokenizer(
full_text, max_length=max_length, truncation=True, return_tensors="pt"
)
prompt = tokenizer(prompt_text, truncation=True, return_tensors="pt")
input_ids = full["input_ids"][0]
attention_mask = full["attention_mask"][0]
labels = input_ids.clone()
prompt_len = min(prompt["input_ids"].shape[1], len(labels))
labels[:prompt_len] = IGNORE_INDEX
return input_ids, attention_mask, labels
def encode_for_inference(
tokenizer: PreTrainedTokenizerBase, question: str, max_length: int = 256
):
"""추론용: prompt까지만 (assistant 응답 시작 직전)."""
prompt_msgs = _build_messages(question, answer=None)
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
enc = tokenizer(
prompt_text, max_length=max_length, truncation=True, return_tensors="pt"
)
return enc["input_ids"][0], enc["attention_mask"][0]
class VQADataset(Dataset):
def __init__(
self,
manifest_path: str,
tokenizer: PreTrainedTokenizerBase,
image_processor: CLIPImageProcessor,
max_length: int = 256,
):
with open(manifest_path, "r", encoding="utf-8") as f:
self.samples = json.load(f)
self.tokenizer = tokenizer
self.image_processor = image_processor
self.max_length = max_length
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
s = self.samples[idx]
image = Image.open(s["image"]).convert("RGB")
pixel_values = self.image_processor(image, return_tensors="pt")[
"pixel_values"
][0]
input_ids, attention_mask, labels = encode_for_training(
self.tokenizer, s["question"], s["answer"], self.max_length
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"pixel_values": pixel_values,
}
@dataclass
class VQACollator:
"""가변 길이 텍스트를 우측 패딩, 이미지는 단순 stack."""
pad_token_id: int
def __call__(self, batch: List[dict]):
max_len = max(item["input_ids"].size(0) for item in batch)
input_ids, attention_mask, labels = [], [], []
for item in batch:
ids = item["input_ids"]
am = item["attention_mask"]
lb = item["labels"]
pad_len = max_len - ids.size(0)
if pad_len > 0:
ids = torch.cat(
[ids, torch.full((pad_len,), self.pad_token_id, dtype=ids.dtype)]
)
am = torch.cat([am, torch.zeros(pad_len, dtype=am.dtype)])
lb = torch.cat(
[lb, torch.full((pad_len,), IGNORE_INDEX, dtype=lb.dtype)]
)
input_ids.append(ids)
attention_mask.append(am)
labels.append(lb)
return {
"input_ids": torch.stack(input_ids),
"attention_mask": torch.stack(attention_mask),
"labels": torch.stack(labels),
"pixel_values": torch.stack([item["pixel_values"] for item in batch]),
}