File size: 4,204 Bytes
89eab7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
CrysMTM Dataset Loading Script

To use this dataset:

1. Download the dataset files from https://huggingface.co/datasets/johnpolat/CrysMTM
2. Place this script in the root of the downloaded files, i.e. the directory
   that contains the ``metadata/`` folder.
3. Run: python load_dataset.py

Alternatively, load it with the Hugging Face ``datasets`` library directly
(``token=True`` replaces the deprecated ``use_auth_token=True`` argument):

    from datasets import load_dataset
    dataset = load_dataset("johnpolat/CrysMTM", token=True)
"""

import os
import pandas as pd
from datasets import Dataset, DatasetDict
from PIL import Image as PILImage

# Metadata columns packed, in this order, into each example's
# ``regression_labels`` list.
_REGRESSION_PROPERTIES = ("HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond")


def _existing_file(data_dir, relative_path):
    """Resolve a metadata path cell against ``data_dir``.

    Returns the joined path, or None when the cell is empty (NaN) or the
    file does not exist.
    """
    if pd.isna(relative_path):
        return None
    path = os.path.join(data_dir, relative_path)
    return path if os.path.isfile(path) else None


def _read_image(path):
    """Load an image as an in-memory RGB ``PIL.Image`` and close the file."""
    # PIL opens files lazily and keeps the handle open; convert() loads the
    # pixels into a new image, so the source file can be closed right away.
    with PILImage.open(path) as img:
        return img.convert("RGB")


def _read_xyz(path):
    """Parse an XYZ file into ``(elements, coordinates)``.

    The first two lines are skipped as header lines (in standard XYZ these
    are the atom count and a comment). Every remaining line with at least
    four whitespace-separated fields contributes its first field as the
    element symbol and the next three as float coordinates.
    """
    elements = []
    coords = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f.readlines()[2:]:
            parts = line.split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(x) for x in parts[1:4]])
    return elements, coords


def _read_text(path):
    """Return the full contents of a UTF-8 text file."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def load_crysmtm_dataset(data_dir, split="train"):
    """Load one CrysMTM split as a Hugging Face ``Dataset``.

    Reads ``<data_dir>/metadata/<split>_metadata.csv``. For every row it
    loads the referenced image, XYZ structure and text file, resolving the
    paths stored in the CSV relative to ``data_dir``.

    Every example carries the same keys. If a modality's path cell is empty
    or its file is missing, the value is ``None`` rather than omitted.
    ``Dataset.from_list`` derives its columns from the first example only,
    so a key missing from that example would otherwise silently drop the
    modality for the whole split.

    Args:
        data_dir: Root directory of the downloaded dataset.
        split: Split name used to select the metadata CSV.

    Returns:
        A ``datasets.Dataset`` with one example per metadata row.

    Raises:
        FileNotFoundError: If the split's metadata CSV does not exist.
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    def load_example(row):
        """Build one example dict (all modalities plus labels) from a CSV row."""
        image_path = _existing_file(data_dir, row["image_path"])
        xyz_path = _existing_file(data_dir, row["xyz_path"])
        text_path = _existing_file(data_dir, row["text_path"])

        if xyz_path is not None:
            elements, coords = _read_xyz(xyz_path)
        else:
            elements, coords = None, None

        return {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            "image": _read_image(image_path) if image_path is not None else None,
            "xyz_coordinates": coords,
            "elements": elements,
            "text": _read_text(text_path) if text_path is not None else None,
            "regression_labels": [row[prop] for prop in _REGRESSION_PROPERTIES],
            "classification_label": row["label"],
        }

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])

def load_dataset(data_dir=".", splits=("train", "test_id", "test_ood")):
    """Load every available CrysMTM split into a ``DatasetDict``.

    A split whose files cannot be found (``FileNotFoundError``) is skipped
    with a printed warning instead of aborting the whole load. The result
    may therefore contain fewer splits than requested, possibly none.

    Args:
        data_dir: Root directory of the downloaded dataset.
        splits: Split names to attempt, in order. The default is the three
            splits this script has always loaded.

    Returns:
        A ``datasets.DatasetDict`` keyed by split name.
    """
    dataset_dict = {}

    for split in splits:
        # Keep the try minimal: only the loader can raise FileNotFoundError.
        try:
            split_dataset = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            print(f"Warning: {split} split not found")
            continue
        dataset_dict[split] = split_dataset
        print(f"Loaded {split} split: {len(split_dataset)} samples")

    return DatasetDict(dataset_dict)

if __name__ == "__main__":
    print("Loading CrysMTM dataset...")
    dataset = load_dataset(".")

    # Do not report success when no split was found at all.
    if not dataset:
        print(
            "\nNo CrysMTM splits were found; make sure this script is in the "
            "directory that contains the metadata/ folder."
        )
        raise SystemExit(1)

    print("\nDataset loaded successfully!")
    print(f"Available splits: {list(dataset.keys())}")

    # Sample from the first split that has rows. Indexing an empty split
    # (e.g. a header-only metadata CSV) would raise IndexError.
    first_split = next((name for name, ds in dataset.items() if len(ds) > 0), None)
    if first_split is not None:
        sample = dataset[first_split][0]
        print(f"\nSample from {first_split} split:")
        print(f"  Phase: {sample['phase']}")
        print(f"  Temperature: {sample['temperature']}K")
        print(f"  Rotation: {sample['rotation']}")
        if sample.get('image') is not None:
            print(f"  Image size: {sample['image'].size}")
        if 'regression_labels' in sample:
            print(f"  Regression labels: {sample['regression_labels']}")
        if 'classification_label' in sample:
            print(f"  Classification label: {sample['classification_label']}")

    print("\n✅ Dataset ready to use!")