|
|
import os |
|
|
import pandas as pd |
|
|
from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence |
|
|
from PIL import Image as PILImage |
|
|
|
|
|
def _existing_path(data_dir, relative_path):
    """Resolve an optional modality path against ``data_dir``.

    Returns the joined path when the metadata cell is non-null and the file
    exists on disk, otherwise None, so missing modalities are skipped rather
    than raising.
    """
    if pd.isna(relative_path):
        return None
    path = os.path.join(data_dir, relative_path)
    return path if os.path.exists(path) else None


def _read_image(path):
    """Load an image file as an in-memory RGB ``PIL.Image``.

    ``convert`` returns a new, fully loaded image, so the source file is
    closed deterministically by the context manager instead of being left
    for PIL / garbage collection to release.
    """
    with PILImage.open(path) as img:
        return img.convert("RGB")


def _read_xyz(path):
    """Parse an XYZ structure file into ``(elements, coordinates)``.

    The first two lines (the file header) are skipped. Every remaining line
    with at least four whitespace-separated fields contributes its first
    field as the element symbol and the next three as float coordinates;
    shorter lines (e.g. trailing blank lines) are ignored.
    """
    with open(path, "r", encoding="utf-8") as fh:
        atom_lines = fh.readlines()[2:]
    elements = []
    coordinates = []
    for line in atom_lines:
        fields = line.split()
        if len(fields) >= 4:
            elements.append(fields[0])
            coordinates.append([float(value) for value in fields[1:4]])
    return elements, coordinates


def _read_text(path):
    """Return the full contents of a text file, decoded explicitly as UTF-8."""
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()


def load_crysmtm_dataset(data_dir, split="train"):
    """Load CrysMTM dataset for a specific split.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per row, attaching whichever modality files (image, XYZ structure, text)
    the row references and that exist on disk.

    Args:
        data_dir: Dataset root directory; the metadata CSV and the modality
            paths listed in it are resolved relative to this directory.
        split: Split name used to select the metadata CSV.

    Returns:
        A ``datasets.Dataset`` with columns ``phase``, ``temperature``,
        ``rotation``, ``split``, ``image``, ``xyz_coordinates``, ``elements``,
        ``text``, ``regression_labels`` and ``classification_label``.
        Modalities that are missing for a row are stored as None.

    Raises:
        FileNotFoundError: If the split's metadata CSV does not exist.
        KeyError: If the CSV lacks one of the columns read below.
    """
    regression_properties = ("HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond")

    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    def load_example(row):
        """Load a single example with all modalities."""
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            # Optional modalities always exist as keys (None when absent):
            # Dataset.from_list derives the column set from the FIRST example
            # only, so a key missing there would drop that column for every row.
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = _existing_path(data_dir, row["image_path"])
        if image_path is not None:
            example["image"] = _read_image(image_path)

        xyz_path = _existing_path(data_dir, row["xyz_path"])
        if xyz_path is not None:
            example["elements"], example["xyz_coordinates"] = _read_xyz(xyz_path)

        text_path = _existing_path(data_dir, row["text_path"])
        if text_path is not None:
            example["text"] = _read_text(text_path)

        example["regression_labels"] = [row[prop] for prop in regression_properties]
        example["classification_label"] = row["label"]
        return example

    # to_dict("records") is faster than iterrows() and keeps per-column dtypes
    # (iterrows builds one Series per row, which can coerce mixed values).
    records = df.to_dict(orient="records")
    return Dataset.from_list([load_example(row) for row in records])
|
|
|
|
|
def load_dataset(data_dir, splits=("train", "test_id", "test_ood")):
    """Load the complete CrysMTM dataset.

    Args:
        data_dir: Dataset root directory passed to ``load_crysmtm_dataset``.
        splits: Names of the splits to load. Defaults to the three CrysMTM
            splits ``train``, ``test_id`` and ``test_ood``; pass a subset
            (e.g. ``("train",)``) to load only what you need.

    Returns:
        A ``DatasetDict`` keyed by split name. A split whose loading raises
        ``FileNotFoundError`` (e.g. its metadata CSV is absent) is skipped
        with a printed warning, so the result may hold fewer entries than
        ``splits`` -- possibly none.
    """
    dataset_dict = {}
    for split in splits:
        try:
            dataset_dict[split] = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            # Best-effort: report the missing split instead of aborting the
            # whole load, so the remaining splits are still returned.
            print(f"Warning: {split} split not found")
    return DatasetDict(dataset_dict)
|
|
|
|
|
|
|
|
def load_crysmtm(data_dir="."):
    """Main function to load CrysMTM dataset.

    Args:
        data_dir: Dataset root directory. Defaults to the current working
            directory.

    Returns:
        The ``DatasetDict`` returned by ``load_dataset(data_dir)``.
    """
    return load_dataset(data_dir)