import os

import pandas as pd
from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence
from PIL import Image as PILImage

# Metadata columns collected, in this order, into ``regression_labels``.
_REGRESSION_PROPERTIES = ("HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond")


def _existing_path(data_dir, relative_path):
    """Return ``data_dir/relative_path`` if it exists on disk, else ``None``.

    ``relative_path`` is a raw metadata cell; a missing (NaN/None) cell also
    yields ``None``.
    """
    if pd.isna(relative_path):
        return None
    full_path = os.path.join(data_dir, relative_path)
    return full_path if os.path.exists(full_path) else None


def _read_xyz(xyz_path):
    """Parse an XYZ file into ``(coordinates, elements)``.

    The first two lines are skipped as header lines. Every remaining line with
    at least four whitespace-separated fields contributes its first field as
    the element and fields 2-4, converted to ``float``, as the coordinates.
    Shorter lines are ignored.
    """
    coords = []
    elements = []
    with open(xyz_path, "r", encoding="utf-8") as f:
        for line in f.readlines()[2:]:  # Skip header lines
            parts = line.strip().split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(x) for x in parts[1:4]])
    return coords, elements


def load_crysmtm_dataset(data_dir, split="train"):
    """Load the CrysMTM dataset for a specific split.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per metadata row, attaching the image, XYZ structure and text referenced
    by that row when the corresponding file exists.

    Args:
        data_dir: Root directory of the dataset; the paths stored in the
            metadata are resolved relative to it.
        split: Split name used to select the metadata file.

    Returns:
        A ``datasets.Dataset`` built from the loaded examples.

    Raises:
        FileNotFoundError: If the metadata CSV for ``split`` does not exist
            (raised by ``pandas.read_csv``).
    """
    # Load metadata
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    def load_example(row):
        """Load a single example with all modalities."""
        # Optional modalities are pre-filled with None: Dataset.from_list
        # derives its columns from the first example only, so a key missing
        # there would silently drop that column for every row.
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        # Load image
        image_path = _existing_path(data_dir, row["image_path"])
        if image_path is not None:
            # convert() returns a fully loaded copy, so the underlying file
            # can be closed right away instead of leaking a handle per row.
            with PILImage.open(image_path) as img:
                example["image"] = img.convert("RGB")

        # Load XYZ coordinates
        xyz_path = _existing_path(data_dir, row["xyz_path"])
        if xyz_path is not None:
            coords, elements = _read_xyz(xyz_path)
            example["xyz_coordinates"] = coords
            example["elements"] = elements

        # Load text
        text_path = _existing_path(data_dir, row["text_path"])
        if text_path is not None:
            # Explicit encoding so the result does not depend on the locale.
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        # Add regression labels
        example["regression_labels"] = [row[prop] for prop in _REGRESSION_PROPERTIES]

        # Add classification label
        example["classification_label"] = row["label"]

        return example

    # Create dataset
    dataset = Dataset.from_list([load_example(row) for _, row in df.iterrows()])
    return dataset


def load_dataset(data_dir):
    """Load the complete CrysMTM dataset.

    Tries the ``train``, ``test_id`` and ``test_ood`` splits in turn. A split
    whose loading raises ``FileNotFoundError`` is reported with a printed
    warning and left out of the result.

    Args:
        data_dir: Root directory of the dataset.

    Returns:
        A ``datasets.DatasetDict`` keyed by the names of the splits that
        were loaded.
    """
    splits = ["train", "test_id", "test_ood"]
    dataset_dict = {}

    for split in splits:
        try:
            dataset_dict[split] = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            print(f"Warning: {split} split not found")

    return DatasetDict(dataset_dict)


# This is the main function that Hugging Face Hub will call
def load_crysmtm():
    """Main function to load the CrysMTM dataset from the current directory."""
    return load_dataset(".")