# Alternative data source (kept for reference): load the mini benchmark
# directly from the Hugging Face Hub instead of a local JSONL file.
#
# from datasets import load_dataset
# raw_ds = load_dataset("simwit/omni-med-vqa-mini")
# full_dataset = raw_ds["test"]
# split = full_dataset.train_test_split(test_size=0.2, seed=42)
# train_dataset = split["train"]
# eval_dataset = split["test"]
# print("✅ SFT Dataset loaded:")
# print(f"   📚 Train samples: {len(train_dataset)}")
# print(f"   🧪 Eval samples: {len(eval_dataset)}")
# print(f"\n📝 Single Sample: [IMAGE] {train_dataset[0]['question']} {train_dataset[0]['gt_answer']} {train_dataset[0]['image_path']} {list(train_dataset[0].keys())}")
"""Convert a JSONL file with ``image`` and ``conversations`` into a HF Dataset.

Each input line is a JSON object with a string ``image`` path and a
``conversations`` list of ``{"from": ..., "value": ...}`` turns (the
``human``/``gpt`` turn format used by LLaVA-style datasets). Each output row
of :func:`jsonl_to_dataset_hf_parallel` contains:

- image     : str -- absolute path of ``image_root / record["image"]``
- question  : str -- the human turn preceding the first assistant answer
- gt_answer : str -- the first ``gpt``/``assistant`` turn

:func:`format_vlm_sample` turns such a row (the "medical sample format") into
the OpenAI-style ``messages`` list used for VLM fine-tuning (e.g. LFM2-VL).
"""
import json
import multiprocessing as mp  # noqa: F401 -- not used below; kept for existing importers
from pathlib import Path
from typing import Dict, List, Optional

import datasets
from PIL import Image  # noqa: F401 -- only needed if images are decoded eagerly

# Default system prompt. Not referenced in this module; presumably imported by
# training scripts -- TODO confirm.
SYSTEM_MSG = "You are a helpful vision-language assistant."

# Placeholder that marks the image position inside the human turn text. The
# image is attached as its own content part by ``format_vlm_sample``, so the
# placeholder is stripped from the question text.
# NOTE(review): the previous code called ``.replace("", "")`` (a no-op), most
# likely because this tag was lost; confirm it matches the token in your data.
IMAGE_TOKEN = "<image>"


def format_vlm_sample(sample: Dict) -> List[Dict]:
    """Format one dataset row into an OpenAI-style chat ``messages`` list.

    Args:
        sample: Mapping with ``image`` (passed through unchanged), ``question``
            and ``gt_answer`` keys, e.g. a row produced by
            :func:`jsonl_to_dataset_hf_parallel`.

    Returns:
        A two-turn conversation: a ``user`` turn holding the image followed by
        the question text, and an ``assistant`` turn holding the answer text.
    """
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": sample["image"]},
                {"type": "text", "text": sample["question"]},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": sample["gt_answer"]}]},
    ]


def jsonl_to_dataset_hf_parallel(
    jsonl_file: str, image_root: str = "", num_workers: Optional[int] = None
) -> datasets.Dataset:
    """Load a conversation JSONL file into a Dataset of image/question/answer rows.

    JSON syntax and the basic record schema are checked serially. Extracting
    the question/answer pair and checking that the image exists run in
    parallel via ``Dataset.map(num_proc=...)``.

    Args:
        jsonl_file: Path to a UTF-8 JSONL file with one JSON object per line.
            Each object has a string ``image`` and a list ``conversations``.
            Blank lines are ignored.
        image_root: Directory that ``image`` paths are resolved against.
            ``""`` means the current working directory. A ``Path`` is also
            accepted.
        num_workers: Number of worker processes for the map step. Defaults
            to 8, and is clamped to ``[1, number of records]``.

    Returns:
        A ``datasets.Dataset`` with string columns ``image`` (absolute path),
        ``question`` and ``gt_answer``. The following records are dropped:
        invalid JSON (a warning is printed with its 1-based line number),
        wrong schema (counted in a warning), a missing image file, or an
        empty question or answer. If no record is usable, the result is an
        empty dataset with the same columns.
    """
    if num_workers is None:
        num_workers = 8
    # Captured by the worker closure; Path objects are picklable.
    root = Path(image_root)

    lines: List[str] = []
    skipped_schema = 0
    with open(jsonl_file, encoding="utf-8") as f:
        # 1-based numbering so warnings match editor line numbers.
        for line_num, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                print(f"Warning: Line {line_num}: Invalid JSON")
                continue
            # Check the schema up front, so the parallel step can index
            # rec["image"] and iterate rec["conversations"] safely.
            if (
                isinstance(rec, dict)
                and isinstance(rec.get("image"), str)
                and isinstance(rec.get("conversations"), list)
            ):
                lines.append(line)
            else:
                skipped_schema += 1

    if skipped_schema:
        print(
            f"Warning: skipped {skipped_schema} line(s) without a string 'image' "
            f"and a list 'conversations'"
        )
    print(f"Found {len(lines)} valid lines to process")

    output_features = datasets.Features(
        {
            "image": datasets.Value("string"),
            "question": datasets.Value("string"),
            "gt_answer": datasets.Value("string"),
        }
    )
    if not lines:
        # Dataset.from_list([]) has no columns, so the column removals below
        # would raise. Return a typed empty dataset instead.
        print("✅ Final dataset size: 0 medical QA samples")
        return datasets.Dataset.from_dict(
            {name: [] for name in output_features}, features=output_features
        )

    def process_example_safe(example: Dict) -> Dict:
        """Turn one raw JSONL line into an image/question/gt_answer row.

        Always returns a dict with the same keys, because ``Dataset.map``
        needs one per row. Unusable records get ``valid=False`` and are
        removed by the filter below. Malformed turns are skipped rather
        than raising, because an exception would abort the whole
        multiprocess map.
        """
        rec = json.loads(example["line"])
        image_path = root / rec["image"]
        image = str(image_path.absolute())
        rejected = {"image": image, "question": "", "gt_answer": "", "valid": False}
        if not image_path.exists():
            return rejected

        question = ""
        gt_answer = ""
        for turn in rec["conversations"]:
            if not isinstance(turn, dict):
                continue
            value = turn.get("value")
            text = value.strip() if isinstance(value, str) else ""
            speaker = turn.get("from")
            if speaker == "human":
                # A later human turn overwrites an earlier one, so the question
                # is the last human turn before the first answer.
                question = text.replace(IMAGE_TOKEN, "").strip()
            elif speaker in ("gpt", "assistant"):
                gt_answer = text
                break  # only the first question/answer pair is used

        if not question or not gt_answer:
            return rejected
        return {"image": image, "question": question, "gt_answer": gt_answer, "valid": True}

    raw_dataset = datasets.Dataset.from_dict({"line": lines})
    num_proc = max(1, min(num_workers, len(lines)))
    processed_dataset = raw_dataset.map(
        process_example_safe,
        num_proc=num_proc,
        remove_columns=raw_dataset.column_names,
        desc="Processing medical QA records",
    )

    # Drop rejected records, then the bookkeeping column.
    valid_dataset = processed_dataset.filter(lambda x: x["valid"])
    valid_dataset = valid_dataset.remove_columns(["valid"])
    print(f"Valid samples after processing: {len(valid_dataset)}")

    # Images are returned as absolute path strings, not decoded images.
    # Decode them where batches are built,
    # e.g. Image.open(path).convert("RGB").
    print(f"✅ Final dataset size: {len(valid_dataset)} medical QA samples")
    return valid_dataset


if __name__ == "__main__":
    # Smoke test: load the default training file and show the first sample.
    ds = jsonl_to_dataset_hf_parallel("data/train.jsonl")
    if len(ds) > 0:
        first = ds[0]
        print("Sample:", first.keys())
        print("Question:", first["question"])
        print("Answer:", first["gt_answer"])