File size: 2,958 Bytes
44052d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence
from PIL import Image as PILImage

# Metadata CSV columns copied, in this order, into each example's
# "regression_labels" list.
_REGRESSION_PROPERTIES = ("HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond")


def _existing_path(data_dir, relative_path):
    """Return ``relative_path`` joined onto ``data_dir`` if it is set and exists, else None.

    ``relative_path`` is a metadata cell; an empty cell (read by pandas as
    NaN) is treated as "no file".
    """
    if pd.isna(relative_path):
        return None
    full_path = os.path.join(data_dir, relative_path)
    return full_path if os.path.exists(full_path) else None


def _read_xyz(xyz_path):
    """Parse an XYZ file into ``(elements, coordinates)``.

    The first two lines are skipped as header lines. Every remaining line with
    at least four whitespace-separated fields contributes its first field as an
    element symbol and the next three as float coordinates; shorter lines are
    ignored.

    Raises:
        ValueError: If one of the three coordinate fields is not a number.
    """
    elements = []
    coords = []
    with open(xyz_path, "r", encoding="utf-8") as f:
        body_lines = f.readlines()[2:]
    for line in body_lines:
        parts = line.split()
        if len(parts) >= 4:
            elements.append(parts[0])
            coords.append([float(x) for x in parts[1:4]])
    return elements, coords


def load_crysmtm_dataset(data_dir, split="train"):
    """Load one split of the CrysMTM dataset as a ``datasets.Dataset``.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per row. The image, XYZ and text paths in the CSV are resolved relative to
    ``data_dir``.

    Args:
        data_dir: Root directory of the dataset.
        split: Split name used to locate the metadata CSV (e.g. "train").

    Returns:
        A ``Dataset`` with the columns phase, temperature, rotation, split,
        image, xyz_coordinates, elements, text, regression_labels and
        classification_label. The image/xyz_coordinates/elements/text columns
        hold None for rows whose path cell is empty or whose file is missing.

    Raises:
        FileNotFoundError: If the metadata CSV does not exist (from
            ``pd.read_csv``); ``load_dataset`` relies on this to skip splits.
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    def load_example(row):
        """Build the example dict for a single metadata row."""
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            # Optional modalities are always present (None when unavailable):
            # Dataset.from_list takes its column set from the FIRST example
            # only, so a key missing there would drop the column for all rows.
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = _existing_path(data_dir, row["image_path"])
        if image_path is not None:
            # convert() returns an independent, loaded image, so the source
            # file can be closed immediately.
            with PILImage.open(image_path) as img:
                example["image"] = img.convert("RGB")

        xyz_path = _existing_path(data_dir, row["xyz_path"])
        if xyz_path is not None:
            example["elements"], example["xyz_coordinates"] = _read_xyz(xyz_path)

        text_path = _existing_path(data_dir, row["text_path"])
        if text_path is not None:
            # Explicit encoding so results don't depend on the platform locale.
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        example["regression_labels"] = [row[prop] for prop in _REGRESSION_PROPERTIES]
        example["classification_label"] = row["label"]
        return example

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])

def load_dataset(data_dir):
    """Load every CrysMTM split that can be found under ``data_dir``.

    The splits "train", "test_id" and "test_ood" are attempted in that order
    via ``load_crysmtm_dataset``. A split whose loader raises
    FileNotFoundError is reported with a printed warning and left out of the
    result instead of aborting the whole load.

    Args:
        data_dir: Root directory passed through to ``load_crysmtm_dataset``.

    Returns:
        A ``DatasetDict`` keyed by split name, containing only the splits
        that loaded successfully.
    """
    loaded = {}
    for split in ("train", "test_id", "test_ood"):
        try:
            split_dataset = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            print(f"Warning: {split} split not found")
            continue
        loaded[split] = split_dataset
    return DatasetDict(loaded)

# Top-level convenience entry point. The original author intended it to be
# what the Hugging Face Hub calls. NOTE(review): confirm how (or whether) the
# Hub actually invokes this function.
def load_crysmtm():
    """Load all available CrysMTM splits from the current working directory.

    Equivalent to calling ``load_dataset(".")``.
    """
    current_dir = "."
    return load_dataset(data_dir=current_dir)