File size: 9,621 Bytes
6d08d46 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 | """
Model inference logic for XRD pattern analysis.
Loads the pretrained model from HuggingFace Hub and runs predictions.
"""
import sys
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import spglib
import torch
class XRDModelInference:
    """Handles loading and inference for the XRD analysis model.

    ``load_model`` downloads the pretrained model from the Hugging Face Hub
    repo ``HF_REPO_ID``; ``predict`` then runs it on a single intensity
    pattern and returns a JSON-friendly dict of crystal-system, space-group
    and lattice-parameter predictions.
    """

    HF_REPO_ID = "linked-liszt/OpenAlphaDiffract"

    # Crystal-system class names, indexed by the cs_logits class index.
    _CRYSTAL_SYSTEMS = (
        "Triclinic", "Monoclinic", "Orthorhombic", "Tetragonal",
        "Trigonal", "Hexagonal", "Cubic",
    )
    # Labels for the lattice-parameter outputs, in output order.
    _LATTICE_LABELS = ("a", "b", "c", "\u03b1", "\u03b2", "\u03b3")

    # Maps space group number (1-230) to a Hall number. spglib's
    # get_spacegroup_type() is indexed by Hall number (1-530), NOT by space
    # group number, so we keep the first (intended: standard-setting) Hall
    # number seen for each space group. Populated lazily by
    # _hall_number_for() so importing this module does not run 530 spglib
    # queries, and no loop variables leak into the class namespace.
    _sg_to_hall: Dict[int, int] = {}

    def __init__(self):
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def is_loaded(self) -> bool:
        """Return True once a model has been loaded successfully."""
        return self.model is not None

    def load_model(self):
        """Download and load the pretrained model from HuggingFace Hub.

        Best effort: any failure is printed (with traceback) and leaves
        ``self.model`` as None, which ``predict`` reports as HTTP 503.
        """
        try:
            from huggingface_hub import snapshot_download

            print(f"Downloading model from {self.HF_REPO_ID}...")
            model_dir = snapshot_download(self.HF_REPO_ID)
            print(f"Model downloaded to {model_dir}")

            # Import the pure-PyTorch model class from the downloaded repo.
            # Prepend the directory only once so repeated calls do not keep
            # growing sys.path.
            # NOTE(review): if a different module named "model" was imported
            # earlier in this process, Python reuses that cached module.
            if model_dir not in sys.path:
                sys.path.insert(0, model_dir)
            from model import AlphaDiffract

            self.model = AlphaDiffract.from_pretrained(
                model_dir, device=str(self.device)
            )
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback

            traceback.print_exc()
            self.model = None

    def preprocess_data(self, x: List[float], y: List[float]) -> torch.Tensor:
        """
        Preprocess XRD data for model input.

        Negative intensities are clipped to zero, then the pattern is min-max
        rescaled to [0, 100]; a flat pattern becomes all zeros.

        Args:
            x: 2theta values. Currently unused; presumably the model expects
                intensities on a fixed 2theta grid (TODO confirm).
            y: Intensity values.

        Returns:
            float32 tensor with a leading batch dimension of 1, on
            ``self.device``.

        Raises:
            ValueError: If ``y`` is empty or contains NaN/infinite values.
        """
        y_array = np.array(y, dtype=np.float32)

        # Fail loudly instead of an opaque NumPy reduction error (empty input)
        # or a silently all-NaN model input (NaN/inf values).
        if y_array.size == 0:
            raise ValueError("Intensity values 'y' must not be empty.")
        if not np.isfinite(y_array).all():
            raise ValueError("Intensity values 'y' must be finite (no NaN or inf).")

        # Floor at zero (remove any negative values)
        y_array = np.maximum(y_array, 0.0)

        # Rescale intensity values to [0, 100] range (matching training preprocessing)
        y_min = np.min(y_array)
        y_max = np.max(y_array)
        if y_max - y_min < 1e-10:
            y_scaled = np.zeros_like(y_array, dtype=np.float32)
        else:
            y_normalized = (y_array - y_min) / (y_max - y_min)
            y_scaled = y_normalized * 100.0

        tensor = torch.from_numpy(y_scaled).unsqueeze(0)
        return tensor.to(self.device)

    def predict(self, x: List[float], y: List[float]) -> Dict:
        """
        Run inference on XRD data.

        Args:
            x: 2theta values
            y: Intensity values

        Returns:
            On success: ``{"status": "success", "predictions": ...,
            "model_info": ...}`` plus ``"confidence"`` (mean of the per-phase
            confidences) when any are available. On failure:
            ``{"status": "error", "error": <message>, "http_status": 503|500}``.
        """
        if self.model is None:
            return {
                "status": "error",
                "error": "Model not loaded.",
                "http_status": 503,
            }

        try:
            input_tensor = self.preprocess_data(x, y)
            with torch.no_grad():
                output = self.model(input_tensor)

            processed = self._process_model_output(output)
            overall_confidence = self._compute_overall_confidence(processed)

            predictions = {
                "status": "success",
                "predictions": processed,
                "model_info": {
                    "type": "AlphaDiffract",
                    "device": str(self.device),
                },
            }
            if overall_confidence is not None:
                predictions["confidence"] = overall_confidence
            return predictions
        except Exception as e:
            # NOTE(review): invalid input (ValueError from preprocess_data) is
            # also reported as 500; a 400 may be more appropriate.
            return {
                "status": "error",
                "error": str(e),
                "http_status": 500,
            }

    def _process_model_output(self, output) -> Dict:
        """Process raw model output into readable predictions.

        A dict output is read for the optional keys ``cs_logits``,
        ``sg_logits`` and ``lp``; a bare tensor is summarised by its shape
        (and softmax peak). Any other type yields empty predictions.
        """
        if isinstance(output, dict):
            predictions = []
            if "cs_logits" in output:
                predictions.append(self._crystal_system_prediction(output["cs_logits"]))
            if "sg_logits" in output:
                predictions.append(self._space_group_prediction(output["sg_logits"]))
            if "lp" in output:
                predictions.append(self._lattice_prediction(output["lp"]))
            return {
                "phase_predictions": predictions,
                "intensity_profile": [],
            }

        if isinstance(output, torch.Tensor):
            probs = output.detach().cpu().numpy()
            confidence = None
            if output.ndim >= 1 and output.shape[-1] > 1:
                prob_tensor = torch.softmax(output, dim=-1)
                confidence = float(prob_tensor.max().item())
            predictions = [{"phase": "Predicted Phase", "details": f"Output shape: {probs.shape}"}]
            if confidence is not None:
                predictions[0]["confidence"] = confidence
            return {
                "phase_predictions": predictions,
                "intensity_profile": probs.tolist() if len(probs.shape) == 1 else [],
            }

        return {"phase_predictions": [], "intensity_profile": []}

    @staticmethod
    def _first_row(logits: torch.Tensor) -> torch.Tensor:
        """Return the class scores of the first sample as a 1-D CPU tensor.

        Accepts both batched (``[batch, classes]``) and unbatched
        (``[classes]``) logits.
        """
        logits = logits.detach().cpu()
        if logits.ndim > 1:
            return logits.reshape(-1, logits.shape[-1])[0]
        return logits

    def _crystal_system_prediction(self, logits: torch.Tensor) -> Dict:
        """Summarise crystal-system logits (7 classes) for the first sample."""
        probs = torch.softmax(self._first_row(logits), dim=-1)
        best_prob, best_idx = torch.max(probs, dim=-1)
        names = self._CRYSTAL_SYSTEMS
        all_probs = [
            {
                "class_name": names[i],
                "probability": float(probs[i].item()),
            }
            for i in range(len(names))
        ]
        all_probs.sort(key=lambda entry: entry["probability"], reverse=True)
        return {
            "phase": "Crystal System",
            "predicted_class": names[best_idx.item()],
            "confidence": float(best_prob.item()),
            "all_probabilities": all_probs,
        }

    def _space_group_prediction(self, logits: torch.Tensor) -> Dict:
        """Summarise space-group logits (230 classes) for the first sample.

        Class index ``i`` corresponds to space group number ``i + 1``; the
        top-10 (or fewer, if there are fewer classes) candidates are listed.
        """
        probs = torch.softmax(self._first_row(logits), dim=-1)
        best_prob, best_idx = torch.max(probs, dim=-1)
        sg_number = int(best_idx.item()) + 1

        top_k = min(10, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, top_k)
        top_entries = [
            {
                "space_group_number": int(idx.item()) + 1,
                "space_group_symbol": self._get_space_group_symbol(int(idx.item()) + 1),
                "probability": float(prob.item()),
            }
            for prob, idx in zip(top_probs, top_indices)
        ]
        return {
            "phase": "Space Group",
            "predicted_class": f"#{sg_number}",
            "space_group_symbol": self._get_space_group_symbol(sg_number),
            "confidence": float(best_prob.item()),
            "top_probabilities": top_entries,
        }

    def _lattice_prediction(self, lp_tensor: torch.Tensor) -> Dict:
        """Label the lattice-parameter regression output (a, b, c, angles)."""
        lp_raw = lp_tensor.detach().cpu()
        # Drop a leading batch dimension of 1; otherwise squeeze singleton dims.
        lp = lp_raw[0] if lp_raw.shape[0] == 1 else lp_raw.squeeze()
        # tolist() (unlike numpy()) also handles bfloat16 tensors.
        values = lp.tolist()
        return {
            "phase": "Lattice Parameters",
            "lattice_params": {
                label: float(val) for label, val in zip(self._LATTICE_LABELS, values)
            },
            "is_lattice": True,
        }

    @staticmethod
    def _spglib_field(sg_type, name: str):
        """Read ``name`` from an spglib space-group record.

        The record may be attribute-style or a dict (presumably depending on
        the spglib version), so both access styles are supported.
        """
        return getattr(sg_type, name) if hasattr(sg_type, name) else sg_type[name]

    @classmethod
    def _hall_number_for(cls, sg_number: int) -> Optional[int]:
        """Return the first Hall number for ``sg_number``, or None if unknown.

        Builds the ``_sg_to_hall`` table on first use by scanning Hall
        numbers 1-530; the table is only stored once fully built.
        """
        if not cls._sg_to_hall:
            table: Dict[int, int] = {}
            for hall in range(1, 531):
                sg_type = spglib.get_spacegroup_type(hall)
                if sg_type is None:
                    continue
                table.setdefault(cls._spglib_field(sg_type, "number"), hall)
            cls._sg_to_hall = table
        return cls._sg_to_hall.get(sg_number)

    def _get_space_group_symbol(self, sg_number: int) -> str:
        """Get the spglib ``international_short`` symbol for a space group.

        Returns ``"SG<number>"`` for numbers outside 1-230 or when the
        spglib lookup fails for any reason (the symbol is cosmetic, so a
        lookup failure must never fail the prediction).
        """
        fallback = f"SG{sg_number}"
        if sg_number < 1 or sg_number > 230:
            return fallback
        try:
            hall_number = self._hall_number_for(sg_number)
            if hall_number is None:
                return fallback
            sg_type = spglib.get_spacegroup_type(hall_number)
            if sg_type is None:
                return fallback
            return self._spglib_field(sg_type, "international_short")
        except Exception:
            return fallback

    def _compute_overall_confidence(self, processed: Dict) -> Optional[float]:
        """Return the mean of available per-phase confidences, or None if none."""
        phase_predictions = (
            processed.get("phase_predictions", []) if isinstance(processed, dict) else []
        )
        confidences = [
            float(p["confidence"])
            for p in phase_predictions
            if isinstance(p, dict) and "confidence" in p and p["confidence"] is not None
        ]
        if not confidences:
            return None
        return float(np.mean(confidences))
|