johnpolat commited on
Commit
239b322
·
verified ·
1 Parent(s): 0812ca4

Upload crysmtm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. crysmtm.py +105 -0
crysmtm.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CrysMTM: A Multiphase, Temperature-Resolved, Multimodal Dataset for Crystalline Materials
3
+ """
4
+
5
+ import os
6
+ import pandas as pd
7
+ from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence
8
+ from PIL import Image as PILImage
9
+
10
+ _CITATION = """\
11
+ @dataset{crysmtm2024,
12
+ title={CrysMTM: A Multiphase, Temperature-Resolved, Multimodal Dataset for Crystalline Materials},
13
+ author={Can Polat and Erchin Serpedin and Mustafa Kurban and Hasan Kurban},
14
+ year={2024},
15
+ url={https://github.com/KurbanIntelligenceLab/CrysMTM}
16
+ }
17
+ """
18
+
19
+ _DESCRIPTION = """\
20
+ CrysMTM is a comprehensive multiphase, temperature-resolved, multimodal dataset for crystalline materials research,
21
+ specifically focused on titanium dioxide (TiO₂) polymorphs. The dataset is designed primarily for regression tasks
22
+ to predict 9 key material properties from multimodal inputs.
23
+ """
24
+
25
+ _HOMEPAGE = "https://github.com/KurbanIntelligenceLab/CrysMTM"
26
+
27
+ _LICENSE = "cc-by-4.0"
28
+
29
+ def load_crysmtm_dataset(data_dir, split="train"):
30
+ """Load CrysMTM dataset for a specific split."""
31
+
32
+ # Load metadata
33
+ metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
34
+ df = pd.read_csv(metadata_path)
35
+
36
+ def load_example(row):
37
+ """Load a single example with all modalities."""
38
+ example = {
39
+ "phase": row["phase"],
40
+ "temperature": row["temperature"],
41
+ "rotation": row["rotation"],
42
+ "split": row["split"]
43
+ }
44
+
45
+ # Load image
46
+ if pd.notna(row["image_path"]):
47
+ image_path = os.path.join(data_dir, row["image_path"])
48
+ if os.path.exists(image_path):
49
+ example["image"] = PILImage.open(image_path).convert("RGB")
50
+
51
+ # Load XYZ coordinates
52
+ if pd.notna(row["xyz_path"]):
53
+ xyz_path = os.path.join(data_dir, row["xyz_path"])
54
+ if os.path.exists(xyz_path):
55
+ with open(xyz_path, 'r') as f:
56
+ lines = f.readlines()[2:] # Skip header lines
57
+ coords = []
58
+ elements = []
59
+ for line in lines:
60
+ parts = line.strip().split()
61
+ if len(parts) >= 4:
62
+ elements.append(parts[0])
63
+ coords.append([float(x) for x in parts[1:4]])
64
+ example["xyz_coordinates"] = coords
65
+ example["elements"] = elements
66
+
67
+ # Load text
68
+ if pd.notna(row["text_path"]):
69
+ text_path = os.path.join(data_dir, row["text_path"])
70
+ if os.path.exists(text_path):
71
+ with open(text_path, 'r') as f:
72
+ example["text"] = f.read()
73
+
74
+ # Add regression labels
75
+ regression_properties = ["HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond"]
76
+ example["regression_labels"] = [row[prop] for prop in regression_properties]
77
+
78
+ # Add classification label
79
+ example["classification_label"] = row["label"]
80
+
81
+ return example
82
+
83
+ # Create dataset
84
+ dataset = Dataset.from_list([load_example(row) for _, row in df.iterrows()])
85
+
86
+ return dataset
87
+
88
+ def load_dataset(data_dir="."):
89
+ """Load the complete CrysMTM dataset."""
90
+
91
+ splits = ["train", "test_id", "test_ood"]
92
+ dataset_dict = {}
93
+
94
+ for split in splits:
95
+ try:
96
+ dataset_dict[split] = load_crysmtm_dataset(data_dir, split)
97
+ except FileNotFoundError:
98
+ print(f"Warning: {split} split not found")
99
+
100
+ return DatasetDict(dataset_dict)
101
+
102
+ # Main function for Hugging Face Hub
103
+ def load_crysmtm():
104
+ """Main function to load CrysMTM dataset."""
105
+ return load_dataset(".")