johnpolat committed on
Commit
89eab7d
·
verified ·
1 Parent(s): 239b322

Upload load_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. load_dataset.py +116 -0
load_dataset.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CrysMTM Dataset Loading Script
3
+
4
+ To use this dataset:
5
+
6
+ 1. Download the dataset files from: https://huggingface.co/datasets/johnpolat/CrysMTM
7
+ 2. Place this script in the same directory as the downloaded files
8
+ 3. Run: python load_dataset.py
9
+
10
+ Or use the Hugging Face datasets library directly:
11
+ from datasets import load_dataset
12
+ dataset = load_dataset("johnpolat/CrysMTM", token=True)
13
+ """
14
+
15
+ import os
16
+ import pandas as pd
17
+ from datasets import Dataset, DatasetDict
18
+ from PIL import Image as PILImage
19
+
20
def load_crysmtm_dataset(data_dir, split="train"):
    """Load one split of the CrysMTM dataset from local files.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per metadata row, attaching the image, XYZ structure and text file
    referenced by that row when the corresponding file exists on disk.

    Args:
        data_dir: Directory that contains the ``metadata`` folder and the
            files referenced by the ``image_path``, ``xyz_path`` and
            ``text_path`` columns.
        split: Name of the split; selects which metadata CSV is read.

    Returns:
        A ``datasets.Dataset`` with the columns ``phase``, ``temperature``,
        ``rotation``, ``split``, ``image``, ``xyz_coordinates``, ``elements``,
        ``text``, ``regression_labels`` and ``classification_label``. A
        modality whose path is empty or whose file is missing is ``None``.

    Raises:
        FileNotFoundError: If the metadata CSV for ``split`` does not exist.
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # The order here defines the order of values inside "regression_labels".
    regression_properties = ["HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond"]

    def resolve(relative_path):
        """Return the on-disk path for a metadata entry, or None if unusable."""
        if pd.isna(relative_path):
            return None
        candidate = os.path.join(data_dir, relative_path)
        return candidate if os.path.exists(candidate) else None

    def read_xyz(path):
        """Parse an XYZ file into (coordinates, element symbols)."""
        with open(path, "r", encoding="utf-8") as f:
            atom_lines = f.readlines()[2:]  # Skip header lines
        coords = []
        elements = []
        for line in atom_lines:
            parts = line.split()
            # Lines with fewer than four fields cannot hold "element x y z".
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(value) for value in parts[1:4]])
        return coords, elements

    def load_example(row):
        """Load a single example with all modalities."""
        # Every modality key is always present (None when unavailable).
        # Dataset.from_list derives its columns from the first record, so a
        # key that is only added conditionally would be dropped for the whole
        # split whenever the first row happens to lack that file.
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = resolve(row["image_path"])
        if image_path is not None:
            # convert() returns an independent in-memory copy, so the file
            # handle can be released as soon as the block exits.
            with PILImage.open(image_path) as image_file:
                example["image"] = image_file.convert("RGB")

        xyz_path = resolve(row["xyz_path"])
        if xyz_path is not None:
            example["xyz_coordinates"], example["elements"] = read_xyz(xyz_path)

        text_path = resolve(row["text_path"])
        if text_path is not None:
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        example["regression_labels"] = [row[prop] for prop in regression_properties]
        example["classification_label"] = row["label"]

        return example

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])
78
+
79
def load_dataset(data_dir="."):
    """Load every available CrysMTM split found under ``data_dir``.

    Tries the ``train``, ``test_id`` and ``test_ood`` splits in that order.
    A split whose loading raises ``FileNotFoundError`` is reported with a
    warning on stdout and left out of the result.

    Args:
        data_dir: Directory containing the downloaded dataset files.

    Returns:
        A ``DatasetDict`` mapping each successfully loaded split name to its
        dataset; empty when no split could be loaded.
    """
    loaded = {}

    for name in ("train", "test_id", "test_ood"):
        try:
            split_data = load_crysmtm_dataset(data_dir, name)
        except FileNotFoundError:
            print(f"Warning: {name} split not found")
            continue

        loaded[name] = split_data
        print(f"Loaded {name} split: {len(split_data)} samples")

    return DatasetDict(loaded)
93
+
94
if __name__ == "__main__":
    # Script entry point: load whatever splits exist in the current
    # directory and print a short preview of the first example.
    print("Loading CrysMTM dataset...")
    dataset = load_dataset(".")

    print("\nDataset loaded successfully!")
    print(f"Available splits: {list(dataset)}")

    # Preview the first example of the first loaded split, if there is one.
    if dataset:
        first_split = next(iter(dataset))
        sample = dataset[first_split][0]
        print(f"\nSample from {first_split} split:")
        print(f" Phase: {sample['phase']}")
        print(f" Temperature: {sample['temperature']}K")
        print(f" Rotation: {sample['rotation']}")

        # The image is optional: the key may be absent or hold None.
        image = sample.get('image')
        if image is not None:
            print(f" Image size: {image.size}")

        optional_fields = (
            ("regression_labels", "Regression labels"),
            ("classification_label", "Classification label"),
        )
        for key, title in optional_fields:
            if key in sample:
                print(f" {title}: {sample[key]}")

    print("\n✅ Dataset ready to use!")