File size: 2,958 Bytes
44052d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import os
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence
from PIL import Image as PILImage
def load_crysmtm_dataset(data_dir, split="train"):
    """Load one split of the CrysMTM dataset as a ``datasets.Dataset``.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per CSV row. Each example carries:

    * the row's ``phase``, ``temperature``, ``rotation`` and ``split`` values;
    * ``image``: the RGB image referenced by ``image_path``;
    * ``xyz_coordinates`` / ``elements``: parsed from the file at ``xyz_path``;
    * ``text``: the contents of the file at ``text_path``;
    * ``regression_labels``: the nine target columns listed below, in order;
    * ``classification_label``: the row's ``label`` value.

    Paths in the CSV are joined onto ``data_dir``. Every example has the same
    keys; when a modality's path cell is empty (NaN) or the referenced file
    does not exist, that modality's value(s) are ``None``.

    Args:
        data_dir: Root directory of the dataset.
        split: Split name; selects ``metadata/<split>_metadata.csv``.

    Returns:
        A ``datasets.Dataset`` built with ``Dataset.from_list``, one example
        per metadata row.

    Raises:
        FileNotFoundError: If the split's metadata CSV does not exist
            (propagated from ``pandas.read_csv``).
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Target columns, in the order they are packed into "regression_labels".
    regression_properties = ["HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond"]

    def resolve(rel_path):
        """Return ``data_dir/rel_path`` if the cell is set and the file exists, else None."""
        if pd.isna(rel_path):
            return None
        path = os.path.join(data_dir, rel_path)
        return path if os.path.exists(path) else None

    def read_xyz(path):
        """Parse an XYZ file into ``(elements, coords)``.

        The first two lines are skipped as header (presumably atom count and
        comment line, per the usual XYZ layout). Any remaining line with at
        least four whitespace-separated fields contributes its first field to
        ``elements`` and the next three, as floats, to ``coords``.
        """
        elements = []
        coords = []
        with open(path, "r", encoding="utf-8") as f:
            body = f.readlines()[2:]  # skip header lines
        for line in body:
            parts = line.split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(x) for x in parts[1:4]])
        return elements, coords

    def load_example(row):
        """Build the example dict (all modalities + labels) for one metadata row."""
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            # Optional modalities default to None so every example has the
            # same keys: Dataset.from_list takes its columns from the first
            # entry, so a key missing there would be dropped for all rows.
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = resolve(row["image_path"])
        if image_path is not None:
            # Close the file handle PIL opens; convert() returns a new,
            # fully loaded image that does not depend on it.
            with PILImage.open(image_path) as img:
                example["image"] = img.convert("RGB")

        xyz_path = resolve(row["xyz_path"])
        if xyz_path is not None:
            elements, coords = read_xyz(xyz_path)
            example["xyz_coordinates"] = coords
            example["elements"] = elements

        text_path = resolve(row["text_path"])
        if text_path is not None:
            # Explicit encoding so decoding does not depend on the platform locale.
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        example["regression_labels"] = [row[prop] for prop in regression_properties]
        example["classification_label"] = row["label"]
        return example

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])
def load_dataset(data_dir):
    """Load every available CrysMTM split into a ``DatasetDict``.

    Tries the ``train``, ``test_id`` and ``test_ood`` splits in that order via
    ``load_crysmtm_dataset``. A split whose loading raises
    ``FileNotFoundError`` (e.g. its metadata CSV is absent) is skipped with a
    ``UserWarning`` instead of aborting the whole load.

    Note: despite the name, this is a local function, unrelated to
    ``datasets.load_dataset``.

    Args:
        data_dir: Root directory of the dataset.

    Returns:
        A ``DatasetDict`` keyed by split name containing only the splits that
        loaded; it is empty if none did.
    """
    import warnings

    splits = ["train", "test_id", "test_ood"]
    dataset_dict = {}
    for split in splits:
        try:
            dataset_dict[split] = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            # Report via the warnings machinery (filterable, written to
            # stderr) rather than print() to stdout.
            warnings.warn(f"{split} split not found", stacklevel=2)
    return DatasetDict(dataset_dict)
# Convenience entry point. The original note says the Hugging Face Hub calls
# this function. NOTE(review): confirm; nothing in this file registers it with the Hub.
def load_crysmtm(data_dir="."):
    """Load the complete CrysMTM dataset rooted at ``data_dir``.

    Args:
        data_dir: Dataset root directory. Defaults to ``"."`` (the current
            working directory), matching the previous zero-argument behavior.

    Returns:
        The ``DatasetDict`` produced by ``load_dataset(data_dir)``.
    """
    return load_dataset(data_dir)