|
|
""" |
|
|
CrysMTM Dataset Loading Script |
|
|
|
|
|
To use this dataset: |
|
|
|
|
|
1. Download the dataset files from: https://huggingface.co/datasets/johnpolat/CrysMTM |
|
|
2. Place this script in the same directory as the downloaded files |
|
|
3. Run: python load_dataset.py |
|
|
|
|
|
Or use the Hugging Face datasets library directly: |
|
|
from datasets import load_dataset |
|
|
dataset = load_dataset("johnpolat/CrysMTM", use_auth_token=True) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import pandas as pd |
|
|
from datasets import Dataset, DatasetDict |
|
|
from PIL import Image as PILImage |
|
|
|
|
|
def _read_xyz(xyz_path):
    """Parse an XYZ file into parallel lists of element symbols and coordinates.

    The first two lines of the file are skipped (presumably the atom-count and
    comment header lines of the XYZ format -- confirm against the data files).
    Lines with fewer than four whitespace-separated fields are ignored.

    Args:
        xyz_path: Path to the XYZ file.

    Returns:
        A ``(elements, coords)`` tuple: ``elements`` holds the first field of
        every atom line, ``coords`` the next three fields as floats.

    Raises:
        ValueError: If a coordinate field cannot be parsed as a float.
    """
    elements = []
    coords = []
    with open(xyz_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f):
            if line_no < 2:
                continue
            parts = line.split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(x) for x in parts[1:4]])
    return elements, coords


def load_crysmtm_dataset(data_dir, split="train"):
    """Load CrysMTM dataset for a specific split.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per row, attaching the image, XYZ structure and text modalities whose
    files exist on disk.

    Every example carries the same set of keys. A modality whose path is
    missing in the metadata, or whose file does not exist, is stored as
    ``None`` rather than being omitted: ``Dataset.from_list`` derives its
    columns from the first example, so omitting a key there would silently
    drop that column for the whole split.

    Args:
        data_dir: Root directory containing ``metadata/`` and the files
            referenced by the ``image_path``, ``xyz_path`` and ``text_path``
            columns (paths are joined onto ``data_dir``).
        split: Split name used to select the metadata CSV.

    Returns:
        A ``datasets.Dataset`` with the columns ``phase``, ``temperature``,
        ``rotation``, ``split``, ``image``, ``xyz_coordinates``, ``elements``,
        ``text``, ``regression_labels`` and ``classification_label``.

    Raises:
        FileNotFoundError: If the metadata CSV for ``split`` does not exist.
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Loop-invariant: metadata columns collected, in this order, into
    # "regression_labels".
    regression_properties = ["HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond"]

    def resolve(relative_path):
        """Return the on-disk path for a metadata entry, or None if unavailable."""
        if pd.isna(relative_path):
            return None
        full_path = os.path.join(data_dir, relative_path)
        return full_path if os.path.exists(full_path) else None

    def load_example(row):
        """Load a single example with all modalities."""
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            # Optional modalities default to None so all rows share one schema.
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = resolve(row["image_path"])
        if image_path is not None:
            # convert() returns a new, fully loaded image, so the underlying
            # file can be closed as soon as the with-block exits.
            with PILImage.open(image_path) as img:
                example["image"] = img.convert("RGB")

        xyz_path = resolve(row["xyz_path"])
        if xyz_path is not None:
            elements, coords = _read_xyz(xyz_path)
            example["xyz_coordinates"] = coords
            example["elements"] = elements

        text_path = resolve(row["text_path"])
        if text_path is not None:
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        example["regression_labels"] = [row[prop] for prop in regression_properties]
        example["classification_label"] = row["label"]

        return example

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])
|
|
|
|
|
def load_dataset(data_dir="."):
    """Load every available CrysMTM split found under ``data_dir``.

    Tries the ``train``, ``test_id`` and ``test_ood`` splits in that order.
    A split whose loading raises ``FileNotFoundError`` is reported with a
    warning and left out of the result.

    Args:
        data_dir: Directory passed through to ``load_crysmtm_dataset``.

    Returns:
        A ``DatasetDict`` mapping each successfully loaded split name to its
        dataset (possibly empty if no split could be loaded).
    """
    loaded = {}

    for name in ("train", "test_id", "test_ood"):
        try:
            split_data = load_crysmtm_dataset(data_dir, name)
        except FileNotFoundError:
            print(f"Warning: {name} split not found")
            continue

        loaded[name] = split_data
        print(f"Loaded {name} split: {len(split_data)} samples")

    return DatasetDict(loaded)
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: load all splits from the current directory and print
    # a short summary of the first example of the first loaded split.
    print("Loading CrysMTM dataset...")
    dataset = load_dataset(".")

    print("\nDataset loaded successfully!")
    print(f"Available splits: {list(dataset.keys())}")

    if dataset:
        first_split = next(iter(dataset))
        sample = dataset[first_split][0]

        print(f"\nSample from {first_split} split:")
        print(f"  Phase: {sample['phase']}")
        print(f"  Temperature: {sample['temperature']}K")
        print(f"  Rotation: {sample['rotation']}")

        # Optional fields are only reported when present in the example.
        if sample.get('image') is not None:
            print(f"  Image size: {sample['image'].size}")
        if 'regression_labels' in sample:
            print(f"  Regression labels: {sample['regression_labels']}")
        if 'classification_label' in sample:
            print(f"  Classification label: {sample['classification_label']}")

    print("\n✅ Dataset ready to use!")