Hafnium49 commited on
Commit
c2ed6db
·
verified ·
1 Parent(s): 05349ff

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. Dockerfile +40 -0
  2. README.md +30 -6
  3. app.py +370 -0
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DGL Backend Dockerfile for MatGL HuggingFace Spaces
2
+ # Reference: https://huggingface.co/docs/hub/en/spaces-sdks-docker
3
+
4
+ FROM python:3.11-slim
5
+
6
+ WORKDIR /app
7
+
8
+ # Install system dependencies as root
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ build-essential \
11
+ libcurl4 \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ # Install dependencies AS ROOT (before creating user)
15
+ # Exact versions from official matgl repo for compatibility
16
+ RUN pip install --no-cache-dir "numpy<2"
17
+ RUN pip install --no-cache-dir torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu
18
+ RUN pip install --no-cache-dir "torchdata<=0.8.0"
19
+ RUN pip install --no-cache-dir dgl==2.2.0 -f https://data.dgl.ai/wheels/torch-2.3/repo.html
20
+ RUN pip install --no-cache-dir matgl pymatgen gradio
21
+
22
+ # Create non-root user AFTER installing packages
23
+ RUN useradd -m -u 1000 user
24
+
25
+ # Make /app writable by user (for model cache)
26
+ RUN chown -R user:user /app
27
+
28
+ # Switch to user for running
29
+ USER user
30
+ ENV HOME=/home/user \
31
+ PATH=/home/user/.local/bin:$PATH \
32
+ HF_HOME=/app/.cache \
33
+ MATGL_BACKEND=DGL
34
+
35
+ # Copy app
36
+ COPY --chown=user app.py /app/
37
+
38
+ EXPOSE 7860
39
+
40
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,10 +1,34 @@
1
  ---
2
- title: Matgl So3net
3
- emoji: 💻
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: docker
7
- pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: MatGL SO3Net
3
+ emoji: 🔬
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
+ app_port: 7860
8
  ---
9
 
10
+ # MatGL SO3Net
11
+
12
+ Material property prediction using **SO3Net-ANI-1x-Subset-PES** from the MatGL library.
13
+
14
+ ## API Endpoints
15
+
16
+ - **Energy**: Predict total and per-atom energy
17
+ - **Forces**: Predict atomic forces
18
+ - **Stress**: Predict stress tensor
19
+ - **Health**: Health check endpoint
20
+
21
+ ## Usage
22
+
23
+ Paste a CIF structure and click Predict to get material properties.
24
+
25
+ ## Model Information
26
+
27
+ - **Model**: `SO3Net-ANI-1x-Subset-PES`
28
+ - **Library**: [MatGL](https://github.com/materialsvirtuallab/matgl)
29
+ - **Backend**: DGL
30
+
31
+ ## Links
32
+
33
+ - [MatGL Documentation](https://matgl.ai/)
34
+ - [MKC-Holmes Project](https://github.com/materialsvirtuallab/matgl)
app.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MatGL PES Model Template (DGL Backend)
3
+ Supports: Energy, Forces, Stress calculation via PESCalculator
4
+
5
+ Replace MODEL_NAME with the actual model name, e.g.:
6
+ - M3GNet-MP-2021.2.8-PES
7
+ - CHGNet-MPtrj-2024.2.13-11M-PES
8
+ """
9
+ import os
10
+ import time
11
+
12
+ # Set DGL backend BEFORE importing matgl (critical for v2.0+)
13
+ os.environ["MATGL_BACKEND"] = "DGL"
14
+
15
+ import gradio as gr
16
+ import torch
17
+ import numpy as np
18
+ import dgl # Import DGL explicitly to ensure it's loaded first
19
+ import matgl
20
+
21
+ # Also set via API (belt and suspenders)
22
+ matgl.set_backend("DGL")
23
+
24
+ # === CONFIGURATION ===
25
+ MODEL_NAME = "SO3Net-ANI-1x-Subset-PES" # Replace with actual model name
26
+ SPACE_TITLE = "MatGL SO3Net" # Replace with space title
27
+ # =====================
28
+
29
+ # Global model cache
30
+ _model = None
31
+ _model_error = None
32
+
33
+
34
+ def get_model():
35
+ """Lazy load the model on first use."""
36
+ global _model, _model_error
37
+
38
+ if _model_error:
39
+ return None, _model_error
40
+
41
+ if _model is None:
42
+ try:
43
+ print(f"Loading {MODEL_NAME}...")
44
+ _model = matgl.load_model(MODEL_NAME)
45
+ print(f"{MODEL_NAME} loaded successfully.")
46
+ except Exception as e:
47
+ _model_error = str(e)
48
+ print(f"Failed to load model: {e}")
49
+ import traceback
50
+ traceback.print_exc()
51
+ return None, _model_error
52
+
53
+ return _model, None
54
+
55
+
56
+ def predict_energy(cif_string: str) -> dict:
57
+ """Predict energy per atom from CIF string."""
58
+ if not cif_string or not cif_string.strip():
59
+ return {"status": "error", "message": "CIF string is empty"}
60
+
61
+ pot, error = get_model()
62
+ if error:
63
+ return {"status": "error", "message": f"Model load failed: {error}"}
64
+
65
+ try:
66
+ start_time = time.time()
67
+ from pymatgen.core import Structure
68
+ from pymatgen.io.ase import AseAtomsAdaptor
69
+ from matgl.ext.ase import PESCalculator
70
+
71
+ struct = Structure.from_str(cif_string, fmt="cif")
72
+ adaptor = AseAtomsAdaptor()
73
+ atoms = adaptor.get_atoms(struct)
74
+
75
+ calc = PESCalculator(potential=pot)
76
+ atoms.calc = calc
77
+
78
+ e_total = atoms.get_potential_energy()
79
+ e_per_atom = e_total / len(atoms)
80
+ elapsed_ms = (time.time() - start_time) * 1000
81
+
82
+ return {
83
+ "status": "success",
84
+ "formula": struct.composition.reduced_formula,
85
+ "num_atoms": len(struct),
86
+ "model": MODEL_NAME,
87
+ "model_type": "pes",
88
+ "properties": {
89
+ "energy_per_atom": {"value": round(e_per_atom, 4), "unit": "eV/atom"},
90
+ "total_energy": {"value": round(e_total, 4), "unit": "eV"},
91
+ },
92
+ "metadata": {
93
+ "backend": "DGL",
94
+ "computation_time_ms": round(elapsed_ms, 1),
95
+ },
96
+ }
97
+ except Exception as e:
98
+ return {"status": "error", "message": str(e)}
99
+
100
+
101
+ def predict_forces(cif_string: str) -> dict:
102
+ """Predict forces from CIF string."""
103
+ if not cif_string or not cif_string.strip():
104
+ return {"status": "error", "message": "CIF string is empty"}
105
+
106
+ pot, error = get_model()
107
+ if error:
108
+ return {"status": "error", "message": f"Model load failed: {error}"}
109
+
110
+ try:
111
+ start_time = time.time()
112
+ from pymatgen.core import Structure
113
+ from pymatgen.io.ase import AseAtomsAdaptor
114
+ from matgl.ext.ase import PESCalculator
115
+
116
+ struct = Structure.from_str(cif_string, fmt="cif")
117
+ adaptor = AseAtomsAdaptor()
118
+ atoms = adaptor.get_atoms(struct)
119
+
120
+ calc = PESCalculator(potential=pot)
121
+ atoms.calc = calc
122
+
123
+ forces = atoms.get_forces()
124
+ elapsed_ms = (time.time() - start_time) * 1000
125
+
126
+ return {
127
+ "status": "success",
128
+ "formula": struct.composition.reduced_formula,
129
+ "num_atoms": len(struct),
130
+ "model": MODEL_NAME,
131
+ "model_type": "pes",
132
+ "properties": {
133
+ "forces": {
134
+ "value": forces.tolist(),
135
+ "unit": "eV/A",
136
+ "shape": list(forces.shape),
137
+ },
138
+ },
139
+ "metadata": {
140
+ "backend": "DGL",
141
+ "computation_time_ms": round(elapsed_ms, 1),
142
+ },
143
+ }
144
+ except Exception as e:
145
+ return {"status": "error", "message": str(e)}
146
+
147
+
148
+ def predict_stress(cif_string: str) -> dict:
149
+ """Predict stress tensor from CIF string."""
150
+ if not cif_string or not cif_string.strip():
151
+ return {"status": "error", "message": "CIF string is empty"}
152
+
153
+ pot, error = get_model()
154
+ if error:
155
+ return {"status": "error", "message": f"Model load failed: {error}"}
156
+
157
+ try:
158
+ start_time = time.time()
159
+ from pymatgen.core import Structure
160
+ from pymatgen.io.ase import AseAtomsAdaptor
161
+ from matgl.ext.ase import PESCalculator
162
+
163
+ struct = Structure.from_str(cif_string, fmt="cif")
164
+ adaptor = AseAtomsAdaptor()
165
+ atoms = adaptor.get_atoms(struct)
166
+
167
+ calc = PESCalculator(potential=pot)
168
+ atoms.calc = calc
169
+
170
+ # ASE returns stress in eV/A^3, convert to GPa
171
+ stress_voigt = atoms.get_stress() # Voigt notation [xx, yy, zz, yz, xz, xy]
172
+ stress_gpa = stress_voigt * 160.21766208 # eV/A^3 to GPa
173
+ elapsed_ms = (time.time() - start_time) * 1000
174
+
175
+ return {
176
+ "status": "success",
177
+ "formula": struct.composition.reduced_formula,
178
+ "num_atoms": len(struct),
179
+ "model": MODEL_NAME,
180
+ "model_type": "pes",
181
+ "properties": {
182
+ "stress_voigt": {
183
+ "value": stress_gpa.tolist(),
184
+ "unit": "GPa",
185
+ "order": ["xx", "yy", "zz", "yz", "xz", "xy"],
186
+ },
187
+ },
188
+ "metadata": {
189
+ "backend": "DGL",
190
+ "computation_time_ms": round(elapsed_ms, 1),
191
+ },
192
+ }
193
+ except Exception as e:
194
+ return {"status": "error", "message": str(e)}
195
+
196
+
197
+ def predict_all(cif_string: str) -> dict:
198
+ """Predict all properties (energy, forces, stress) from CIF string."""
199
+ if not cif_string or not cif_string.strip():
200
+ return {"status": "error", "message": "CIF string is empty"}
201
+
202
+ pot, error = get_model()
203
+ if error:
204
+ return {"status": "error", "message": f"Model load failed: {error}"}
205
+
206
+ try:
207
+ start_time = time.time()
208
+ from pymatgen.core import Structure
209
+ from pymatgen.io.ase import AseAtomsAdaptor
210
+ from matgl.ext.ase import PESCalculator
211
+
212
+ struct = Structure.from_str(cif_string, fmt="cif")
213
+ adaptor = AseAtomsAdaptor()
214
+ atoms = adaptor.get_atoms(struct)
215
+
216
+ calc = PESCalculator(potential=pot)
217
+ atoms.calc = calc
218
+
219
+ e_total = atoms.get_potential_energy()
220
+ e_per_atom = e_total / len(atoms)
221
+ forces = atoms.get_forces()
222
+ stress_voigt = atoms.get_stress()
223
+ stress_gpa = stress_voigt * 160.21766208
224
+ elapsed_ms = (time.time() - start_time) * 1000
225
+
226
+ return {
227
+ "status": "success",
228
+ "formula": struct.composition.reduced_formula,
229
+ "num_atoms": len(struct),
230
+ "model": MODEL_NAME,
231
+ "model_type": "pes",
232
+ "properties": {
233
+ "energy_per_atom": {"value": round(e_per_atom, 4), "unit": "eV/atom"},
234
+ "total_energy": {"value": round(e_total, 4), "unit": "eV"},
235
+ "forces": {
236
+ "value": forces.tolist(),
237
+ "unit": "eV/A",
238
+ "shape": list(forces.shape),
239
+ },
240
+ "stress_voigt": {
241
+ "value": stress_gpa.tolist(),
242
+ "unit": "GPa",
243
+ "order": ["xx", "yy", "zz", "yz", "xz", "xy"],
244
+ },
245
+ },
246
+ "metadata": {
247
+ "backend": "DGL",
248
+ "computation_time_ms": round(elapsed_ms, 1),
249
+ },
250
+ }
251
+ except Exception as e:
252
+ return {"status": "error", "message": str(e)}
253
+
254
+
255
+ def health_check() -> dict:
256
+ """Health check endpoint."""
257
+ pot, error = get_model()
258
+ if error:
259
+ return {"status": "error", "message": error}
260
+ return {"status": "healthy", "model": MODEL_NAME, "backend": "DGL"}
261
+
262
+
263
+ def model_info() -> dict:
264
+ """Return model information."""
265
+ return {
266
+ "model": MODEL_NAME,
267
+ "backend": "DGL",
268
+ "model_type": "pes",
269
+ "capabilities": ["energy", "forces", "stress"],
270
+ "units": {
271
+ "energy": "eV",
272
+ "energy_per_atom": "eV/atom",
273
+ "forces": "eV/A",
274
+ "stress": "GPa",
275
+ },
276
+ }
277
+
278
+
279
+ # Example NaCl CIF
280
+ EXAMPLE_CIF = """data_NaCl
281
+ _symmetry_space_group_name_H-M 'F m -3 m'
282
+ _cell_length_a 5.64
283
+ _cell_length_b 5.64
284
+ _cell_length_c 5.64
285
+ _cell_angle_alpha 90
286
+ _cell_angle_beta 90
287
+ _cell_angle_gamma 90
288
+ loop_
289
+ _atom_site_label
290
+ _atom_site_type_symbol
291
+ _atom_site_fract_x
292
+ _atom_site_fract_y
293
+ _atom_site_fract_z
294
+ Na1 Na 0.0 0.0 0.0
295
+ Cl1 Cl 0.5 0.5 0.5
296
+ """
297
+
298
+ # Create Gradio interface
299
+ with gr.Blocks() as demo:
300
+ gr.Markdown(f"# {SPACE_TITLE}")
301
+ gr.Markdown(f"PES model for energy, forces, and stress prediction using **{MODEL_NAME}**.")
302
+ gr.Markdown("**Note:** First prediction may take longer as the model loads on-demand.")
303
+
304
+ with gr.Tab("Energy"):
305
+ with gr.Row():
306
+ with gr.Column():
307
+ cif_energy = gr.Textbox(
308
+ label="CIF Structure",
309
+ placeholder="Paste CIF content here...",
310
+ lines=10,
311
+ value=EXAMPLE_CIF,
312
+ )
313
+ energy_btn = gr.Button("Predict Energy", variant="primary")
314
+ with gr.Column():
315
+ energy_output = gr.JSON(label="Result")
316
+ energy_btn.click(predict_energy, inputs=cif_energy, outputs=energy_output)
317
+
318
+ with gr.Tab("Forces"):
319
+ with gr.Row():
320
+ with gr.Column():
321
+ cif_forces = gr.Textbox(
322
+ label="CIF Structure",
323
+ placeholder="Paste CIF content here...",
324
+ lines=10,
325
+ value=EXAMPLE_CIF,
326
+ )
327
+ forces_btn = gr.Button("Predict Forces", variant="primary")
328
+ with gr.Column():
329
+ forces_output = gr.JSON(label="Result")
330
+ forces_btn.click(predict_forces, inputs=cif_forces, outputs=forces_output)
331
+
332
+ with gr.Tab("Stress"):
333
+ with gr.Row():
334
+ with gr.Column():
335
+ cif_stress = gr.Textbox(
336
+ label="CIF Structure",
337
+ placeholder="Paste CIF content here...",
338
+ lines=10,
339
+ value=EXAMPLE_CIF,
340
+ )
341
+ stress_btn = gr.Button("Predict Stress", variant="primary")
342
+ with gr.Column():
343
+ stress_output = gr.JSON(label="Result")
344
+ stress_btn.click(predict_stress, inputs=cif_stress, outputs=stress_output)
345
+
346
+ with gr.Tab("All Properties"):
347
+ with gr.Row():
348
+ with gr.Column():
349
+ cif_all = gr.Textbox(
350
+ label="CIF Structure",
351
+ placeholder="Paste CIF content here...",
352
+ lines=10,
353
+ value=EXAMPLE_CIF,
354
+ )
355
+ all_btn = gr.Button("Predict All", variant="primary")
356
+ with gr.Column():
357
+ all_output = gr.JSON(label="Result")
358
+ all_btn.click(predict_all, inputs=cif_all, outputs=all_output)
359
+
360
+ with gr.Tab("Info"):
361
+ info_btn = gr.Button("Get Model Info")
362
+ info_output = gr.JSON()
363
+ info_btn.click(model_info, outputs=info_output)
364
+
365
+ health_btn = gr.Button("Health Check")
366
+ health_output = gr.JSON()
367
+ health_btn.click(health_check, outputs=health_output)
368
+
369
+ if __name__ == "__main__":
370
+ demo.launch(server_name="0.0.0.0", server_port=7860)