# finetune_hf / loader.py
# ambivalent02's picture
# Upload loader.py with huggingface_hub
# e5519c9 verified
# from datasets import load_dataset
# raw_ds = load_dataset("simwit/omni-med-vqa-mini")
# full_dataset = raw_ds["test"]
# split = full_dataset.train_test_split(test_size=0.2, seed=42)
# train_dataset = split["train"]
# eval_dataset = split["test"]
# print("✅ SFT Dataset loaded:")
# print(f" 📚 Train samples: {len(train_dataset)}")
# print(f" 🧪 Eval samples: {len(eval_dataset)}")
# print(f"\n📝 Single Sample: [IMAGE] {train_dataset[0]['question']} {train_dataset[0]['gt_answer']} {train_dataset[0]['image_path']} {list(train_dataset[0].keys())}")
"""
Convert jsonl with `image` and `conversations` into
a HuggingFace Dataset that LFM2-VL expects.
Each sample must contain:
- image : str (absolute path or relative to repo root)
- messages: List[Dict] # openai-style
"""
import json, datasets
from pathlib import Path
from typing import List, Dict
import multiprocessing as mp
from PIL import Image
SYSTEM_MSG = "You are a helpful vision-language assistant."
"""
Convert jsonl with `image` and `conversations` into
a HuggingFace Dataset that works with the medical sample format.
"""
import json, datasets
from pathlib import Path
from typing import List, Dict
import multiprocessing as mp
from PIL import Image
def format_vlm_sample(sample):
    """Turn one QA record into the chat-message list the trainer expects.

    `sample` must provide `image`, `question`, and `gt_answer` keys; the
    result is an OpenAI-style two-turn conversation (user then assistant).
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": sample["image"]},
            {"type": "text", "text": sample["question"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["gt_answer"]}],
    }
    return [user_turn, assistant_turn]
def jsonl_to_dataset_hf_parallel(jsonl_file: str, image_root: str = "", num_workers: int = None):
    """
    Build a HuggingFace Dataset of (image, question, gt_answer) rows from a
    jsonl file whose records carry `image` and `conversations` keys.

    Args:
        jsonl_file: Path to the jsonl file; one JSON record per line.
        image_root: Directory prepended to each record's relative image path.
        num_workers: Parallelism for `Dataset.map`; defaults to 8 when None.

    Returns:
        datasets.Dataset with columns `image` (absolute path string),
        `question`, and `gt_answer`. Records whose image file is missing or
        whose question/answer is empty are dropped.
    """
    if num_workers is None:
        num_workers = 8

    # First pass: keep only lines that parse as JSON and carry the keys we
    # need; the heavier path/content checks happen in the parallel map below.
    valid_lines = []
    with open(jsonl_file, encoding="utf-8") as f:
        for line_num, line in enumerate(f):
            line = line.strip()
            if not line:  # skip blank lines
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:  # narrow: only malformed JSON, never KeyboardInterrupt etc.
                print(f"Warning: Line {line_num}: Invalid JSON")
                continue
            if "image" in rec and "conversations" in rec:
                valid_lines.append({"line": line, "image_root": image_root, "line_num": line_num})
    print(f"Found {len(valid_lines)} valid lines to process")

    raw_dataset = datasets.Dataset.from_list(valid_lines)

    def process_example_safe(example):
        """Parse one record; always return a full row (datasets.map forbids None).

        Bad records are returned with valid=False and filtered out afterwards.
        """
        rec = json.loads(example["line"])
        image_path = Path(example["image_root"]) / rec["image"]
        abs_path = str(image_path.absolute())
        if not image_path.exists():
            # Placeholder row instead of None so map() schemas stay consistent.
            return {"image": abs_path, "question": "dummy", "gt_answer": "dummy", "valid": False}

        # Extract the first human question and the first assistant answer.
        question = ""
        gt_answer = ""
        for turn in rec["conversations"]:
            if turn["from"] == "human":
                question = turn["value"].replace("<image>", "").strip()
            elif turn["from"] in ("gpt", "assistant"):
                gt_answer = turn["value"].strip()
                break  # stop at the first answer turn

        if not question or not gt_answer:
            return {"image": abs_path, "question": "dummy", "gt_answer": "dummy", "valid": False}
        return {"image": abs_path, "question": question, "gt_answer": gt_answer, "valid": True}

    processed_dataset = raw_dataset.map(
        process_example_safe,
        num_proc=num_workers,
        remove_columns=["line", "image_root", "line_num"],
        desc="Processing medical QA records",
    )

    # Drop placeholder rows, then the bookkeeping column.
    valid_dataset = processed_dataset.filter(lambda x: x["valid"])
    valid_dataset = valid_dataset.remove_columns(["valid"])
    print(f"Valid samples after processing: {len(valid_dataset)}")
    # NOTE: eager PIL image loading was deliberately removed; images stay as
    # path strings so downstream code can load them lazily and bound memory.
    print(f"✅ Final dataset size: {len(valid_dataset)} medical QA samples")
    return valid_dataset
if __name__ == "__main__":
    # Smoke-test the loader against the default training file.
    dataset = jsonl_to_dataset_hf_parallel("data/train.jsonl")
    if len(dataset) > 0:
        first = dataset[0]
        print("Sample:", first.keys())
        print("Question:", first["question"])
        print("Answer:", first["gt_answer"])