# CrysMTM / dataset_loading_script.py
# Uploaded by johnpolat: "Upload dataset_loading_script.py with huggingface_hub"
# (commit 44052d1, verified; 2.96 kB)
import os
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence
from PIL import Image as PILImage
# Metadata columns packed, in this order, into each example's
# "regression_labels" list.
_REGRESSION_PROPERTIES = ("HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond")


def _parse_xyz_file(xyz_path):
    """Read an XYZ file into parallel lists of element symbols and coordinates.

    The first two lines are skipped unconditionally (in the XYZ format these
    are the atom count and a free-form comment line). Every later line with at
    least four whitespace-separated fields contributes its first field as the
    element symbol and its next three fields, parsed with ``float``, as a
    coordinate triple. Shorter lines (e.g. blank trailing lines) are ignored,
    and fields past the fourth are discarded.

    Args:
        xyz_path: Path of the XYZ file to read.

    Returns:
        A tuple ``(elements, coords)``. ``elements`` is a list of str and
        ``coords`` is a list of ``[x, y, z]`` float lists, both in file order.

    Raises:
        ValueError: if a coordinate field cannot be parsed as a float.
    """
    elements = []
    coords = []
    with open(xyz_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f):
            if line_no < 2:  # header: atom count + comment line
                continue
            parts = line.split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(x) for x in parts[1:4]])
    return elements, coords


def load_crysmtm_dataset(data_dir, split="train"):
    """Load one CrysMTM split as a ``datasets.Dataset``.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per CSV row. Each example contains:

    * ``phase``, ``temperature``, ``rotation`` and ``split``, copied from the row;
    * ``image`` (RGB PIL image), ``xyz_coordinates`` / ``elements`` (from an
      XYZ file) and ``text`` (file contents), loaded from the files named in
      the ``image_path``, ``xyz_path`` and ``text_path`` columns, resolved
      relative to ``data_dir``;
    * ``regression_labels``, the values of the HOMO, LUMO, Eg, Ef, Et, Eta,
      disp, vol and bond columns, in that order;
    * ``classification_label``, taken from the ``label`` column.

    A modality whose path cell is empty/NaN, or whose file does not exist, is
    stored as ``None``. Its key is still present, so every example has the
    same set of keys.

    Args:
        data_dir: Root directory of the dataset.
        split: Split name used to locate the metadata CSV.

    Returns:
        A ``datasets.Dataset`` with one row per metadata row.

    Raises:
        FileNotFoundError: if the metadata CSV for ``split`` does not exist.
        KeyError: if a row lacks one of the required metadata columns.
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    def resolve(path_cell):
        """Join a metadata path cell onto data_dir.

        Returns None when the cell is empty/NaN or the file does not exist.
        """
        if pd.isna(path_cell):
            return None
        path = os.path.join(data_dir, path_cell)
        return path if os.path.exists(path) else None

    def load_example(row):
        """Assemble one example dict (all modalities) from a metadata row."""
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            # Optional modalities default to None so every example has the
            # same keys. Dataset.from_list takes its column set from the FIRST
            # example only, so a key missing there (e.g. because that row's
            # image file is absent) would be dropped for the entire split.
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = resolve(row["image_path"])
        if image_path is not None:
            # The context manager closes the file handle; convert() returns a
            # separate, fully loaded copy that outlives it.
            with PILImage.open(image_path) as img:
                example["image"] = img.convert("RGB")

        xyz_path = resolve(row["xyz_path"])
        if xyz_path is not None:
            elements, coords = _parse_xyz_file(xyz_path)
            example["xyz_coordinates"] = coords
            example["elements"] = elements

        text_path = resolve(row["text_path"])
        if text_path is not None:
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        example["regression_labels"] = [row[prop] for prop in _REGRESSION_PROPERTIES]
        example["classification_label"] = row["label"]
        return example

    # Plain dict records support the same row["column"] access as the Series
    # yielded by iterrows(), without building a Series per row.
    rows = df.to_dict(orient="records")
    return Dataset.from_list([load_example(row) for row in rows])
def load_dataset(data_dir):
    """Collect every available CrysMTM split under *data_dir* into a DatasetDict.

    The splits "train", "test_id" and "test_ood" are tried in that order via
    load_crysmtm_dataset. If loading a split raises FileNotFoundError, that
    split is skipped and a warning is printed, rather than aborting the whole
    load. The result therefore holds only the splits that loaded, which may be
    none.
    """
    loaded = {}
    for split_name in ("train", "test_id", "test_ood"):
        try:
            split_ds = load_crysmtm_dataset(data_dir, split_name)
        except FileNotFoundError:
            print(f"Warning: {split_name} split not found")
            continue
        loaded[split_name] = split_ds
    return DatasetDict(loaded)
# Convenience entry point that loads every split relative to the current
# working directory. NOTE(review): an earlier comment said the Hugging Face Hub
# calls this function; `datasets` loading scripts are normally expected to
# define a DatasetBuilder subclass instead, so confirm how this is invoked.
def load_crysmtm():
    """Load all CrysMTM splits rooted at the current working directory (".")."""
    current_dir = "."
    return load_dataset(current_dir)