"""
CrysMTM Dataset Loading Script

To use this dataset:
1. Download the dataset files from: https://huggingface.co/datasets/johnpolat/CrysMTM
2. Place this script in the same directory as the downloaded files
3. Run: python load_dataset.py

Or use the Hugging Face datasets library directly:
from datasets import load_dataset
dataset = load_dataset("johnpolat/CrysMTM", use_auth_token=True)
"""

import os

import pandas as pd
from datasets import Dataset, DatasetDict
from PIL import Image as PILImage

# Metadata columns collected, in this order, into each example's
# "regression_labels" list.
_REGRESSION_PROPERTIES = (
    "HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond",
)

# Split names that load_dataset() tries to load; each one maps to
# "<data_dir>/metadata/<split>_metadata.csv".
_SPLITS = ("train", "test_id", "test_ood")


def _resolve_existing(data_dir, relative_path):
    """Return the path of an optional data file, or None if unavailable.

    Args:
        data_dir: Root directory of the downloaded dataset.
        relative_path: Path taken from a metadata cell; may be NaN/None when
            the modality is not recorded for that row.

    Returns:
        ``os.path.join(data_dir, relative_path)`` when the cell is set and the
        file exists on disk, otherwise ``None``.
    """
    if pd.isna(relative_path):
        return None
    full_path = os.path.join(data_dir, relative_path)
    return full_path if os.path.exists(full_path) else None


def _read_xyz(xyz_path: str) -> tuple:
    """Parse an XYZ file into parallel element and coordinate lists.

    The first two lines are skipped as header lines. Every following line
    with at least four whitespace-separated fields contributes one atom:
    field 0 is kept as the element string and fields 1-3 are converted to
    floats. Shorter lines are ignored.

    Args:
        xyz_path: Path of the XYZ file to read.

    Returns:
        A ``(elements, coords)`` tuple where ``elements`` is a list of strings
        and ``coords`` is a list of ``[x, y, z]`` float lists.

    Raises:
        ValueError: If a coordinate field cannot be converted to ``float``.
    """
    elements = []
    coords = []
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(xyz_path, "r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle):
            if line_number < 2:
                continue  # Skip header lines
            parts = line.split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(value) for value in parts[1:4]])
    return elements, coords


def _load_example(row, data_dir: str) -> dict:
    """Build one example dict, with all modalities, from a metadata row.

    Args:
        row: One row of the metadata table (a pandas Series as yielded by
            ``DataFrame.iterrows``).
        data_dir: Root directory that the ``*_path`` columns are relative to.

    Returns:
        A dict with the keys "phase", "temperature", "rotation", "split",
        "image", "xyz_coordinates", "elements", "text", "regression_labels"
        and "classification_label". The four modality keys are ``None`` when
        the corresponding path is missing or the file does not exist.
    """
    # Every example carries the same keys: Dataset.from_list() takes its
    # columns from the first example only, so a modality absent from the
    # first row would otherwise be dropped for the whole split.
    example = {
        "phase": row["phase"],
        "temperature": row["temperature"],
        "rotation": row["rotation"],
        "split": row["split"],
        "image": None,
        "xyz_coordinates": None,
        "elements": None,
        "text": None,
    }

    # Load image
    image_path = _resolve_existing(data_dir, row["image_path"])
    if image_path is not None:
        # convert() returns an independent, fully loaded copy, so the file
        # opened by PIL can be closed as soon as the block exits.
        with PILImage.open(image_path) as image:
            example["image"] = image.convert("RGB")

    # Load XYZ coordinates
    xyz_path = _resolve_existing(data_dir, row["xyz_path"])
    if xyz_path is not None:
        elements, coords = _read_xyz(xyz_path)
        example["xyz_coordinates"] = coords
        example["elements"] = elements

    # Load text
    text_path = _resolve_existing(data_dir, row["text_path"])
    if text_path is not None:
        with open(text_path, "r", encoding="utf-8") as handle:
            example["text"] = handle.read()

    # Add regression labels
    example["regression_labels"] = [row[prop] for prop in _REGRESSION_PROPERTIES]

    # Add classification label
    example["classification_label"] = row["label"]

    return example


def load_crysmtm_dataset(data_dir: str, split: str = "train") -> Dataset:
    """Load CrysMTM dataset for a specific split.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and turns every row
    into an example containing the row's scalar fields, the image, XYZ
    structure and text it points to (when those files exist), the regression
    labels and the classification label.

    Args:
        data_dir: Root directory of the downloaded dataset files.
        split: Name of the split whose metadata file should be read.

    Returns:
        A ``datasets.Dataset`` with one example per metadata row.

    Raises:
        FileNotFoundError: If the metadata CSV for ``split`` does not exist
            (raised by ``pandas.read_csv``).
        KeyError: If the metadata table lacks one of the expected columns.
    """
    # Load metadata
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Create dataset
    examples = [_load_example(row, data_dir) for _, row in df.iterrows()]
    return Dataset.from_list(examples)


def load_dataset(data_dir: str = ".") -> DatasetDict:
    """Load the complete CrysMTM dataset.

    Tries the "train", "test_id" and "test_ood" splits in turn. A split whose
    metadata file is missing is reported with a warning and left out of the
    result instead of aborting the load.

    Args:
        data_dir: Root directory of the downloaded dataset files.

    Returns:
        A ``datasets.DatasetDict`` mapping each split that was found to its
        ``Dataset``; it is empty when no split could be loaded.
    """
    dataset_dict = {}

    for split in _SPLITS:
        try:
            dataset_dict[split] = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            print(f"Warning: {split} split not found")
        else:
            print(f"Loaded {split} split: {len(dataset_dict[split])} samples")

    return DatasetDict(dataset_dict)


def main() -> None:
    """Load the dataset from the current directory and print a short summary."""
    print("Loading CrysMTM dataset...")
    dataset = load_dataset(".")

    if not dataset:
        # Nothing was loaded, so do not claim success.
        print("\nNo splits were loaded; check that the metadata files are present.")
        return

    print("\nDataset loaded successfully!")
    print(f"Available splits: {list(dataset.keys())}")

    # Show sample data
    first_split = next(iter(dataset))
    sample = dataset[first_split][0]
    print(f"\nSample from {first_split} split:")
    print(f"  Phase: {sample['phase']}")
    print(f"  Temperature: {sample['temperature']}K")
    print(f"  Rotation: {sample['rotation']}")

    if 'image' in sample and sample['image'] is not None:
        print(f"  Image size: {sample['image'].size}")

    if 'regression_labels' in sample:
        print(f"  Regression labels: {sample['regression_labels']}")

    if 'classification_label' in sample:
        print(f"  Classification label: {sample['classification_label']}")

    print("\n✅ Dataset ready to use!")


if __name__ == "__main__":
    main()