# Uploaded by Hafnium49 via huggingface_hub ("Upload folder using huggingface_hub", commit 2c4496c, verified).
"""
MatGL Phonon DOS Calculator HuggingFace Space
Uses Gradio Queue to enable long-running calculations (5-15 minutes).
Model: M3GNet-MP-2021.2.8-PES (Universal Potential)
Backend: DGL
"""
import os
import json
import time
import tempfile
import warnings
# Set DGL backend BEFORE importing matgl
os.environ["MATGL_BACKEND"] = "DGL"
import gradio as gr
import numpy as np
import matgl
from matgl.ext.ase import PESCalculator, Relaxer
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from ase.phonons import Phonons
from ase.build import make_supercell
# MatGL and DGL emit a lot of warnings; silence all of them so the Space
# logs only show this app's own progress messages.
warnings.filterwarnings("ignore")

# --- Calculation settings ---------------------------------------------------
# Pretrained MatGL potential used for both relaxation and force evaluation.
MODEL_NAME = "M3GNet-MP-2021.2.8-PES"
# Diagonal supercell expansion (2 repetitions along each lattice vector).
SUPERCELL_MATRIX = [
    [2, 0, 0],
    [0, 2, 0],
    [0, 0, 2],
]
# Finite-displacement step handed to ase.phonons.Phonons(delta=...).
DISPLACEMENT_DELTA = 0.01
# Force-convergence criterion for the relaxation, in eV/Angstrom.
RELAX_FMAX = 0.001

# Cached potential. Stays None until get_model() loads it on first use.
_model = None
def get_model():
    """Return the shared MatGL potential, loading it on the first call.

    The loaded model is stored in the module-level ``_model`` so that later
    calls return the same object instead of loading it again.
    """
    global _model
    if _model is not None:
        return _model
    print(f"[Phonon Worker] Loading MatGL Model: {MODEL_NAME}...")
    _model = matgl.load_model(MODEL_NAME)
    print("[Phonon Worker] Model Loaded.")
    return _model
# ASE reports phonon energies in eV; frequency nu = E / h. Both constants are
# exact SI values, giving ~241.799 THz per eV.
_EV_TO_THZ = 1.602176634e-19 / 6.62607015e-34 / 1e12
# Modes below this frequency mark the structure as dynamically unstable.
# ASE reports imaginary modes as negative values; the small negative margin
# tolerates numerical noise near the acoustic branches.
_IMAGINARY_TOL_THZ = -0.05
# Monkhorst-Pack mesh used by ase Phonons.get_dos.
_DOS_KPTS = (20, 20, 20)
# Number of grid points and Gaussian smearing width (eV) of the sampled DOS.
_DOS_NPTS = 100
_DOS_WIDTH_EV = 1e-3


def _supercell_repeats(matrix) -> tuple:
    """Return the (n1, n2, n3) repetitions encoded by a diagonal supercell matrix.

    ``ase.phonons.Phonons`` only supports axis-aligned repetitions, so the
    matrix must be 3x3 and diagonal, with positive integer entries.

    Raises:
        ValueError: if ``matrix`` does not satisfy those constraints.
    """
    mat = np.asarray(matrix)
    if mat.shape != (3, 3) or np.any(mat != np.diag(np.diag(mat))):
        raise ValueError(f"Supercell matrix must be 3x3 diagonal, got {matrix!r}")
    diagonal = np.diag(mat)
    repeats = tuple(int(n) for n in diagonal)
    if any(n < 1 or n != d for n, d in zip(repeats, diagonal)):
        raise ValueError(
            f"Supercell repetitions must be positive integers, got {matrix!r}"
        )
    return repeats


def calculate_phonon_dos(cif_string: str, progress=gr.Progress(track_tqdm=True)) -> dict:
    """
    Calculate the phonon density of states of a crystal given as CIF text.

    Pipeline:
      1. Parse the CIF.
      2. Relax the structure with the MatGL potential.
      3. Compute finite-displacement force constants on a supercell
         (ase.phonons.Phonons).
      4. Sample the DOS on a k-point mesh.

    This is a long-running calculation (5-15 minutes). The Gradio queue keeps
    the client connection alive while it runs.

    Args:
        cif_string: Crystal structure in CIF format.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        dict: On success, ``status="success"`` plus the following:
        - reduced formula and atom counts;
        - ``energies_thz``: the DOS grid converted from ASE's eV to THz;
        - ``dos_weights``: the smeared DOS, rescaled to a per-THz density;
        - ``is_stable_dynamic``: dynamic-stability flag;
        - timing and model name.
        On any failure, returns ``{"status": "error", "message": ...}``.
        Exceptions are not propagated.
    """
    start_time = time.time()
    if not cif_string or not cif_string.strip():
        return {"status": "error", "message": "Empty CIF input."}
    try:
        # Step 1: Parse structure
        progress(0.1, desc="Step 1/5: Parsing Structure...")
        struct = Structure.from_str(cif_string, fmt="cif")
        atoms = AseAtomsAdaptor.get_atoms(struct)
        formula = struct.composition.reduced_formula
        num_atoms = len(atoms)
        print(f"[Phonon] Formula: {formula} ({num_atoms} atoms)")

        # Step 2: High-precision relaxation (residual forces must be tiny for
        # finite-displacement force constants to be meaningful).
        progress(0.2, desc="Step 2/5: High-Precision Relaxation...")
        pot = get_model()
        relaxer = Relaxer(potential=pot)
        relax_result = relaxer.relax(atoms, fmax=RELAX_FMAX)
        # relax_result["final_structure"] is a pymatgen Structure; convert to ASE Atoms.
        relaxed_atoms = AseAtomsAdaptor.get_atoms(relax_result["final_structure"])
        print("[Phonon] Relaxation complete.")

        # Step 3: Supercell. ASE builds it from the relaxed unit cell. Passing
        # supercell=(n1, n2, n3) keeps the lattice translations ASE needs to
        # Fourier-interpolate the dynamical matrix to arbitrary q.
        # Pre-building the supercell and passing (1, 1, 1) would instead yield
        # only the supercell's Gamma-point modes, and would need n1*n2*n3
        # times more displacements.
        repeats = _supercell_repeats(SUPERCELL_MATRIX)
        repeats_label = "x".join(str(n) for n in repeats)
        progress(0.3, desc=f"Step 3/5: Generating Supercell ({repeats_label})...")
        supercell_size = len(relaxed_atoms) * int(np.prod(repeats))
        print(f"[Phonon] Supercell: {supercell_size} atoms")

        # Step 4: Finite displacements (THE SLOW STEP)
        progress(0.4, desc="Step 4/5: Calculating Forces (6N+1 displacements)...")
        calc = PESCalculator(potential=pot)
        with tempfile.TemporaryDirectory() as tmpdirname:
            ph = Phonons(
                relaxed_atoms,
                calc,
                supercell=repeats,
                delta=DISPLACEMENT_DELTA,
                name=os.path.join(tmpdirname, "phonon_cache"),
            )
            # ASE displaces each unit-cell atom by +/-delta along x, y and z,
            # plus one undisplaced reference: 6N+1 supercell force calculations
            # (N = atoms in the unit cell).
            ph.run()
            # Assemble force constants from the cached forces (enforcing the
            # acoustic sum rule) before the temporary cache is deleted.
            ph.read(acoustic=True)

        # Step 5: Compute DOS
        progress(0.9, desc="Step 5/5: Computing Density of States...")
        raw_dos = ph.get_dos(kpts=_DOS_KPTS)
        # Frequencies of every sampled mode (imaginary modes appear negative).
        mode_thz = np.asarray(raw_dos.get_energies()) * _EV_TO_THZ
        dos = raw_dos.sample_grid(npts=_DOS_NPTS, width=_DOS_WIDTH_EV)
        energies = np.asarray(dos.get_energies()) * _EV_TO_THZ
        # Changing the x-axis from eV to THz rescales a density by 1/_EV_TO_THZ,
        # so its integral over the returned axis is preserved.
        weights = np.asarray(dos.get_weights()) / _EV_TO_THZ
        computation_time = round(time.time() - start_time, 2)

        # Stability check on the actual mode frequencies. The smeared grid is
        # padded below the lowest mode, so checking it would give false
        # "unstable" verdicts.
        is_stable = bool(np.min(mode_thz) >= _IMAGINARY_TOL_THZ)
        print(f"[Phonon] Complete in {computation_time}s. Stable: {is_stable}")
        return {
            "status": "success",
            "formula": formula,
            "num_atoms": num_atoms,
            "supercell_atoms": supercell_size,
            "energies_thz": energies.tolist(),
            "dos_weights": weights.tolist(),
            "is_stable_dynamic": is_stable,
            "computation_time_sec": computation_time,
            "model": MODEL_NAME,
            "energy_range": {
                "min": float(np.min(energies)),
                "max": float(np.max(energies)),
                "unit": "THz",
            },
        }
    except Exception as e:
        # API boundary: report any failure as a structured error payload so
        # gradio_client callers always receive JSON. Log the traceback for debugging.
        import traceback

        traceback.print_exc()
        print(f"[Phonon] Error: {e}")
        return {
            "status": "error",
            "message": str(e),
        }
def health_check() -> dict:
    """Report whether the MatGL potential can be loaded.

    Calls get_model(), which loads and caches the model on first use.

    Returns:
        dict: ``{"status": "ok", "model": MODEL_NAME, "backend": "DGL"}`` on
        success, or ``{"status": "error", "message": ...}`` if loading raises.
    """
    try:
        # Only the side effect matters: the model gets loaded/cached.
        get_model()
    except Exception as e:
        return {
            "status": "error",
            "message": str(e),
        }
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "backend": "DGL",
    }
# Example CIF for testing; it pre-fills the "CIF Structure" textbox in the UI.
# NaCl in space group F m -3 m (a = b = c = 5.64, all angles 90 degrees), with
# one Na site and one Cl site.
# NOTE(review): no _symmetry_equiv_pos_as_xyz loop is given, so this relies on
# the CIF parser deriving the symmetry operations from the H-M symbol to
# expand the two sites into the full cell. Confirm with pymatgen's CifParser.
EXAMPLE_CIF = """data_NaCl
_symmetry_space_group_name_H-M 'F m -3 m'
_cell_length_a 5.64
_cell_length_b 5.64
_cell_length_c 5.64
_cell_angle_alpha 90.
_cell_angle_beta 90.
_cell_angle_gamma 90.
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Na 0.00000 0.00000 0.00000
Cl 0.50000 0.50000 0.50000
"""
# Build the Gradio interface. Requests go through the queue enabled at the
# bottom, so a single calculation can run far longer than a plain HTTP request.
with gr.Blocks(title="MatGL Phonon DOS Calculator") as demo:
    # Blank lines separate paragraphs; without them Markdown merges the
    # "Model" / "Note" lines into one paragraph. MODEL_NAME keeps the banner in
    # sync with the model actually loaded.
    gr.Markdown(
        f"""
        # MatGL Phonon DOS Calculator

        **Model:** {MODEL_NAME} (Universal Potential)

        **Note:** Phonon calculations take **5-15 minutes** depending on crystal complexity.
        The connection will be held open until completion.

        ## Usage

        1. Paste your CIF structure below
        2. Click "Calculate Phonons"
        3. Wait for results (progress bar shows status)

        ## Output

        - Phonon DOS (Density of States)
        - Dynamic stability assessment (imaginary modes = unstable)
        - Energy range in THz
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            cif_input = gr.Textbox(
                label="CIF Structure",
                placeholder="Paste CIF content here...",
                lines=15,
                value=EXAMPLE_CIF,
            )
            calculate_btn = gr.Button("Calculate Phonons", variant="primary")
        with gr.Column(scale=1):
            output = gr.JSON(label="Phonon DOS Data")

    calculate_btn.click(
        fn=calculate_phonon_dos,
        inputs=cif_input,
        outputs=output,
        api_name="/predict",  # Expose API endpoint for gradio_client
    )

    # Health check endpoint (for monitoring)
    with gr.Accordion("API Endpoints", open=False):
        health_btn = gr.Button("Health Check")
        health_output = gr.JSON(label="Health Status")
    health_btn.click(fn=health_check, outputs=health_output)

# Enable queueing for long-running jobs. This keeps requests from hitting the
# ~60 s HTTP timeout; at most 5 requests may wait in the queue.
demo.queue(max_size=5).launch(server_name="0.0.0.0", server_port=7860)