// Hugging Face page residue (not part of the module): "linked-liszt's picture",
// commit message "Upload folder using huggingface_hub", revision 6d08d46 (verified).
import React, { createContext, useContext, useState, useMemo, useCallback } from 'react'
// Context object shared by XRDProvider and useXRD. It is created without a
// default value, so reading it outside a provider yields undefined.
const XRDContext = createContext()
/**
 * Hook giving access to the value published by XRDProvider.
 *
 * @returns {object} The current context value.
 * @throws {Error} When the context value is falsy, i.e. no XRDProvider is
 *   mounted above the calling component.
 */
export const useXRD = () => {
  const xrd = useContext(XRDContext)
  if (xrd) {
    return xrd
  }
  throw new Error('useXRD must be used within XRDProvider')
}
export const XRDProvider = ({ children }) => {
// Model training specifications (from simulator.yaml).
// The processing pipeline below crops uploaded patterns to
// [MODEL_MIN_2THETA, MODEL_MAX_2THETA] and resamples them to MODEL_INPUT_SIZE points.
const MODEL_INPUT_SIZE = 8192
const MODEL_WAVELENGTH = 0.6199 // Ångströms (synchrotron)
const MODEL_MIN_2THETA = 5.0 // degrees
const MODEL_MAX_2THETA = 20.0 // degrees
// Raw data from file upload: rawData is the parsed { x, y } pair, filename the
// name of the file (or example) it came from. Both are null until something is loaded.
const [rawData, setRawData] = useState(null)
const [filename, setFilename] = useState(null)
// Wavelength management.
// detectedWavelength: value found in the file's metadata, or null.
// userWavelength: the source wavelength actually used for conversion; it starts
// at the model wavelength and is overwritten when a wavelength is detected.
const [detectedWavelength, setDetectedWavelength] = useState(null)
const [userWavelength, setUserWavelength] = useState(MODEL_WAVELENGTH)
// NOTE(review): only 'detected' and 'default' are assigned in the visible code;
// 'user' is listed but setWavelengthSource is not exposed through the context.
const [wavelengthSource, setWavelengthSource] = useState('default') // 'detected', 'user', 'default'
// Processing parameters
const [baselineCorrection, setBaselineCorrection] = useState(false)
// NOTE(review): interpolationEnabled is exposed through the context but is not
// read by the processing pipeline in this file — confirm whether it should gate
// the resampling step.
const [interpolationEnabled, setInterpolationEnabled] = useState(true)
const [scalingEnabled, setScalingEnabled] = useState(true)
const [interpolationStrategy, setInterpolationStrategy] = useState('linear') // 'linear' or 'cubic'
// Warnings and metadata produced while processing the current pattern
const [dataWarnings, setDataWarnings] = useState([])
// Model results from API
const [modelResults, setModelResults] = useState(null)
const [isLoading, setIsLoading] = useState(false)
const [analysisStatus, setAnalysisStatus] = useState('IDLE') // IDLE, PROCESSING, COMPLETE
// UI state
const [isLogitDrawerOpen, setIsLogitDrawerOpen] = useState(false)
// Request tracking - ensure every click creates a new request.
// Incremented by runInference and sent along in the request metadata.
const [analysisCount, setAnalysisCount] = useState(0)
// Convert wavelength using Bragg's law: λ = 2d·sin(θ)
// For same d-spacing: sin(θ₂) = (λ₂/λ₁)·sin(θ₁)
/**
 * Map a Bragg angle measured at one wavelength onto the angle at which the
 * same d-spacing diffracts at another wavelength.
 *
 * @param {number} theta_deg - Angle θ in degrees at the source wavelength.
 * @param {number} sourceWavelength - Wavelength the angle was measured at.
 * @param {number} targetWavelength - Wavelength to convert to (same unit).
 * @returns {number|null} The converted angle in degrees; the input unchanged
 *   when the wavelengths differ by less than 0.0001; or null when the required
 *   sine exceeds 1 in magnitude (no reflection exists at the target wavelength).
 */
const convertWavelength = (theta_deg, sourceWavelength, targetWavelength) => {
  const wavelengthsMatch = Math.abs(sourceWavelength - targetWavelength) < 0.0001
  if (wavelengthsMatch) {
    return theta_deg
  }

  const wavelengthRatio = targetWavelength / sourceWavelength
  const sineAtTarget = wavelengthRatio * Math.sin((theta_deg * Math.PI) / 180)

  // |sin| > 1 has no real arcsine: the peak cannot be observed at the target wavelength.
  return Math.abs(sineAtTarget) > 1
    ? null
    : (Math.asin(sineAtTarget) * 180) / Math.PI
}
// Interpolate data to fixed size for model input
/**
 * Resample an (x, y) series onto an evenly spaced grid of `targetSize` points.
 *
 * Grid points that fall outside the range covered by `x` are set to 0 rather
 * than extrapolated.
 *
 * @param {number[]} x - Sample positions.
 * @param {number[]} y - Sample values, parallel to `x`.
 * @param {number} targetSize - Number of points in the output grid.
 * @param {number} [xMin] - Grid start; defaults to the smallest value in `x`.
 * @param {number} [xMax] - Grid end; defaults to the largest value in `x`.
 * @param {string} [strategy='linear'] - 'cubic' selects the simplified
 *   Catmull-Rom formula; any other value uses linear interpolation.
 * @returns {{x: number[], y: number[]}} The resampled series. The inputs are
 *   returned untouched when `x` already has `targetSize` points and no `xMin`
 *   was given.
 */
const interpolateData = (x, y, targetSize, xMin, xMax, strategy = 'linear') => {
  if (x.length === targetSize && xMin === undefined) {
    return { x, y }
  }

  // Single pass instead of Math.min(...x) / Math.max(...x): spreading a very
  // large array into call arguments can exceed the engine's argument limit.
  let dataMinX = Infinity
  let dataMaxX = -Infinity
  for (const value of x) {
    if (value < dataMinX) dataMinX = value
    if (value > dataMaxX) dataMaxX = value
  }

  const minX = xMin !== undefined ? xMin : dataMinX
  const maxX = xMax !== undefined ? xMax : dataMaxX
  // A one-point grid has no spacing; avoid the 0/0 that would make it NaN.
  const step = targetSize > 1 ? (maxX - minX) / (targetSize - 1) : 0
  const newX = Array.from({ length: targetSize }, (_, i) => minX + i * step)
  const newY = new Array(targetSize)

  const useCubic = strategy === 'cubic'
  const lastIndex = x.length - 1

  // `cursor` is the first index whose x value is >= the current grid point.
  // Grid points never decrease (for step >= 0), so that index never moves
  // backwards and one forward sweep replaces a fresh search per grid point.
  let cursor = 0

  for (let i = 0; i < targetSize; i++) {
    const targetX = newX[i]

    // Out of the data range: set to 0 instead of extrapolating.
    if (targetX < dataMinX || targetX > dataMaxX) {
      newY[i] = 0
      continue
    }

    // A descending grid (xMax < xMin) breaks the monotonic assumption; restart the search.
    if (step < 0) cursor = 0
    while (cursor < x.length && !(x[cursor] >= targetX)) cursor++

    let idx = cursor < x.length ? cursor : lastIndex
    if (idx === 0) idx = 1

    const lo = idx - 1
    const hi = Math.min(lastIndex, idx)

    // Single-point input or two samples sharing the same x: there is no
    // interval to interpolate across, so take the sample value directly
    // instead of dividing by zero.
    if (hi === lo || x[hi] === x[lo]) {
      newY[i] = y[lo]
      continue
    }

    if (!useCubic) {
      newY[i] = y[lo] + ((targetX - x[lo]) * (y[hi] - y[lo])) / (x[hi] - x[lo])
      continue
    }

    // Simplified Catmull-Rom: the two bracketing samples plus one neighbour on
    // each side, with the neighbours clamped at the ends of the data.
    const before = Math.max(0, idx - 2)
    const after = Math.min(lastIndex, idx + 1)
    const t = (targetX - x[lo]) / (x[hi] - x[lo])
    const t2 = t * t
    const t3 = t2 * t
    const v0 = y[before]
    const v1 = y[lo]
    const v2 = y[hi]
    const v3 = y[after]
    newY[i] = 0.5 * (
      2 * v1 +
      (-v0 + v2) * t +
      (2 * v0 - 5 * v1 + 4 * v2 - v3) * t2 +
      (-v0 + 3 * v1 - 3 * v2 + v3) * t3
    )
  }

  return { x: newX, y: newY }
}
// Process data with optional interpolation
/**
 * Memoized model-ready pattern derived from rawData.
 *
 * Pipeline: wavelength conversion -> optional baseline subtraction -> crop to
 * the model's 2θ window -> optional 0-100 scaling -> resampling to
 * MODEL_INPUT_SIZE points over [MODEL_MIN_2THETA, MODEL_MAX_2THETA].
 *
 * Yields null when no data is loaded, { x, y } on success, and the unprocessed
 * rawData when any step throws (the error message is stored in dataWarnings).
 *
 * NOTE(review): dataWarnings is updated from inside the memo callback, i.e.
 * during render. Kept as-is because consumers read the warnings from state.
 */
const processedData = useMemo(() => {
  if (!rawData) return null

  // reduce instead of Math.min(...values): spreading a very large array into
  // call arguments can exceed the engine's argument limit. Results are the
  // same, including Infinity / -Infinity for an empty array.
  const minOf = (values) => values.reduce((acc, v) => Math.min(acc, v), Infinity)
  const maxOf = (values) => values.reduce((acc, v) => Math.max(acc, v), -Infinity)

  try {
    const warnings = []
    let processedY = [...rawData.y]
    let processedX = [...rawData.x]

    // Step 1: Wavelength conversion (if needed)
    const sourceWavelength = userWavelength
    if (sourceWavelength && Math.abs(sourceWavelength - MODEL_WAVELENGTH) > 0.0001) {
      const convertedData = []
      for (let i = 0; i < processedX.length; i++) {
        // The x axis holds 2θ (it is cropped against the model's 2θ window
        // below), while Bragg's law λ = 2d·sin(θ) is written in terms of θ.
        // Halve before converting and double afterwards.
        const convertedTheta = convertWavelength(processedX[i] / 2, sourceWavelength, MODEL_WAVELENGTH)
        if (convertedTheta !== null) {
          convertedData.push({ x: 2 * convertedTheta, y: processedY[i] })
        }
      }
      if (convertedData.length < processedX.length) {
        warnings.push(`${processedX.length - convertedData.length} points outside physical range after wavelength conversion`)
      }
      processedX = convertedData.map(d => d.x)
      processedY = convertedData.map(d => d.y)
      warnings.push(`Converted from ${sourceWavelength.toFixed(4)} Å to ${MODEL_WAVELENGTH} Å`)
    }

    // Step 2: Apply baseline correction if enabled (subtract the global minimum)
    if (baselineCorrection) {
      const baseline = minOf(processedY)
      processedY = processedY.map(val => val - baseline)
    }

    // Step 3: Crop to model's 2θ range (5-20°)
    const inRangeData = []
    for (let i = 0; i < processedX.length; i++) {
      if (processedX[i] >= MODEL_MIN_2THETA && processedX[i] <= MODEL_MAX_2THETA) {
        inRangeData.push({ x: processedX[i], y: processedY[i] })
      }
    }
    if (inRangeData.length === 0) {
      warnings.push(`⚠️ No data points in model range (${MODEL_MIN_2THETA}-${MODEL_MAX_2THETA}°)`)
      // Use original data but warn
      inRangeData.push(...processedX.map((x, i) => ({ x, y: processedY[i] })))
    } else if (inRangeData.length < processedX.length) {
      const coverage = (inRangeData.length / processedX.length * 100).toFixed(1)
      warnings.push(`${coverage}% of data in model range (${MODEL_MIN_2THETA}-${MODEL_MAX_2THETA}°)`)
    }
    const croppedX = inRangeData.map(d => d.x)
    let croppedY = inRangeData.map(d => d.y)

    // Step 4: Apply 0-100 scaling if enabled (matching training data)
    // NOTE: Scaling happens AFTER cropping so the max peak in the visible range = 100
    if (scalingEnabled) {
      const minY = minOf(croppedY)
      const maxY = maxOf(croppedY)
      // A flat (or empty) pattern has no range to scale by; leave it unchanged.
      if (maxY - minY > 0) {
        croppedY = croppedY.map(val => ((val - minY) / (maxY - minY)) * 100)
      }
    }

    // Step 5: Interpolate to model input size with fixed range
    const interpolated = interpolateData(
      croppedX,
      croppedY,
      MODEL_INPUT_SIZE,
      MODEL_MIN_2THETA,
      MODEL_MAX_2THETA,
      interpolationStrategy
    )

    // Update warnings
    setDataWarnings(warnings)
    return {
      x: interpolated.x,
      y: interpolated.y
    }
  } catch (error) {
    console.error('Error processing data:', error)
    setDataWarnings([`Error: ${error.message}`])
    return rawData
  }
}, [rawData, baselineCorrection, userWavelength, interpolationStrategy, scalingEnabled])
// Extract metadata from CIF/DIF files
/**
 * Scan file text line by line for header metadata.
 *
 * Wavelength: the first line that yields a value wins. An explicit number on
 * the line (accepted only when strictly between 0.1 and 3.0) takes precedence;
 * a named radiation type (Cu/Mo/Co/Cr Kα) on that line is used only when no
 * explicit number was accepted. Cell parameters, space group and crystal
 * system keep the last matching line.
 *
 * @param {string} text - Full file contents.
 * @returns {{wavelength: number|null, cellParams: string|null,
 *   spaceGroup: string|null, crystalSystem: string|null}}
 */
const extractMetadata = (text) => {
  const metadata = {
    wavelength: null,
    cellParams: null,
    spaceGroup: null,
    crystalSystem: null
  }

  // Common wavelength patterns in headers
  const wavelengthPatterns = [
    /wavelength[:\s=]+([0-9.]+)/i,
    /lambda[:\s=]+([0-9.]+)/i,
    /wave[:\s=]+([0-9.]+)/i,
    /_pd_wavelength[:\s]+([0-9.]+)/i, // CIF format
    /_diffrn_radiation_wavelength[:\s]+([0-9.]+)/i, // CIF format
    /radiation.*?([0-9.]+)\s*[AÅ]/i,
  ]

  // Nominal wavelengths for named radiation types, checked in this order.
  const namedRadiation = [
    [/Cu\s*K[αa]/i, 1.5406], // Cu Kα
    [/Mo\s*K[αa]/i, 0.7107], // Mo Kα
    [/Co\s*K[αa]/i, 1.7889], // Co Kα
    [/Cr\s*K[αa]/i, 2.2897], // Cr Kα
  ]

  for (const line of text.split('\n')) {
    if (metadata.wavelength === null) {
      for (const pattern of wavelengthPatterns) {
        const match = line.match(pattern)
        if (match && match[1]) {
          const wavelength = parseFloat(match[1])
          if (wavelength > 0.1 && wavelength < 3.0) { // Reasonable X-ray range
            metadata.wavelength = wavelength
            break
          }
        }
      }
      // Fall back to the nominal value only when the line carried no usable
      // number, so "Cu Ka wavelength: 1.5418" keeps 1.5418 instead of being
      // overwritten with the generic Cu Kα constant.
      if (metadata.wavelength === null) {
        const named = namedRadiation.find(([pattern]) => pattern.test(line))
        if (named) {
          metadata.wavelength = named[1]
        }
      }
    }

    // The capture patterns below are case-insensitive so that any line
    // accepted regardless of case is also extracted regardless of case.

    // Cell parameters (DIF format)
    const cellMatch = line.match(/CELL PARAMETERS:\s*([\d.\s]+)/i)
    if (cellMatch) {
      metadata.cellParams = cellMatch[1].trim()
    }

    // Space group number
    const groupMatch = line.match(/(?:SPACE GROUP:|_symmetry_Int_Tables_number)[:\s]+(\d+)/i)
    if (groupMatch) {
      metadata.spaceGroup = groupMatch[1]
    }

    // Crystal system (numeric code)
    const systemMatch = line.match(/Crystal System:\s*(\d+)/i)
    if (systemMatch) {
      metadata.crystalSystem = systemMatch[1]
    }
  }

  return metadata
}
// Parse CIF format data
/**
 * Extract (2θ, intensity) pairs from the `loop_` tables of CIF text.
 *
 * A loop is used only when its header names both an angle column
 * (`_pd_meas_angle_2theta` / `_pd_calc_angle_2theta`) and an intensity column
 * (`_pd_proc_intensity*` / `_pd_calc_intensity*` / `_pd_meas_counts`). When
 * several header names match, the last one wins. Rows whose selected fields
 * are not numeric are skipped.
 *
 * @param {string} text - Full file contents.
 * @returns {{x: number[], y: number[]}} Parsed columns; both empty when no
 *   suitable loop exists.
 */
const parseCIF = (text) => {
  const x = []
  const y = []
  let inDataLoop = false
  let columnCount = 0
  let thetaIndex = -1
  let intensityIndex = -1

  // Forget everything about the current loop, including which columns were
  // selected, so a later loop can never be read with an earlier loop's indices.
  const closeLoop = () => {
    inDataLoop = false
    columnCount = 0
    thetaIndex = -1
    intensityIndex = -1
  }

  for (const rawLine of text.split('\n')) {
    const line = rawLine.trim()

    // Start of a new loop: begin with a clean header.
    if (line === 'loop_') {
      closeLoop()
      inDataLoop = true
      continue
    }

    if (!inDataLoop) continue

    // Header line: record the column position of the fields we care about.
    if (line.startsWith('_')) {
      if (/_pd_meas_angle_2theta/i.test(line) || /_pd_calc_angle_2theta/i.test(line)) {
        thetaIndex = columnCount
      }
      if (/_pd_proc_intensity/i.test(line) || /_pd_calc_intensity/i.test(line) || /_pd_meas_counts/i.test(line)) {
        intensityIndex = columnCount
      }
      columnCount++
      continue
    }

    // Blank lines, comments and anything else starting with "loop_" are not data rows.
    if (line.length === 0 || line.startsWith('#') || line.startsWith('loop_')) continue

    // A new data block ends the current loop.
    if (line.startsWith('data_')) {
      closeLoop()
      continue
    }

    // First row of a loop that lacks the needed columns: abandon the loop.
    if (thetaIndex < 0 || intensityIndex < 0) {
      closeLoop()
      continue
    }

    const parts = line.split(/\s+/)
    if (parts.length >= Math.max(thetaIndex, intensityIndex) + 1) {
      const xVal = parseFloat(parts[thetaIndex])
      const yVal = parseFloat(parts[intensityIndex])
      if (!Number.isNaN(xVal) && !Number.isNaN(yVal)) {
        x.push(xVal)
        y.push(yVal)
      }
    }
  }

  return { x, y }
}
// Parse DIF or XY format (space-separated 2theta intensity)
/**
 * Parse two-column numeric text (DIF / XY / CSV / TXT): first field is the
 * angle, second the intensity; extra fields are ignored.
 *
 * Fields may be separated by whitespace, commas or semicolons. Blank lines,
 * lines starting with `#`, `_` or a letter (headers and metadata), and lines
 * whose first two fields are not numeric are skipped.
 *
 * NOTE(review): a decimal comma ("10,5") is treated as a field separator.
 *
 * @param {string} text - Full file contents.
 * @returns {{x: number[], y: number[]}} Parsed columns, possibly empty.
 */
const parseDIF = (text) => {
  const x = []
  const y = []

  for (const line of text.split('\n')) {
    const trimmed = line.trim()

    // The leading-letter test also covers "CELL ..." and "SPACE ..." header lines.
    if (!trimmed ||
        trimmed.startsWith('#') ||
        trimmed.startsWith('_') ||
        /^[a-zA-Z]/.test(trimmed)) {
      continue
    }

    // Accept comma/semicolon-delimited rows as well as whitespace-delimited
    // ones, so "10.0,200" is read as two fields instead of being dropped.
    const parts = trimmed.split(/[\s,;]+/)
    if (parts.length < 2) continue

    const xVal = parseFloat(parts[0])
    const yVal = parseFloat(parts[1])
    if (!Number.isNaN(xVal) && !Number.isNaN(yVal)) {
      x.push(xVal)
      y.push(yVal)
    }
  }

  return { x, y }
}
// Parse uploaded file
/**
 * Read an uploaded file as text and parse it into { x, y } arrays.
 *
 * Files ending in `.cif` go through parseCIF first and fall back to parseDIF
 * when no loop data is found; everything else goes straight to parseDIF.
 *
 * Wavelength state is updated only after the file is known to contain data,
 * so a file that fails to parse does not change the wavelength applied to the
 * pattern that is still loaded.
 *
 * NOTE(review): when no wavelength is detected, userWavelength keeps whatever
 * value it had (possibly from a previous file) while the source is reported
 * as 'default' — confirm whether it should be reset to MODEL_WAVELENGTH.
 *
 * @param {File} file - File selected by the user.
 * @returns {Promise<{x: number[], y: number[]}>} Resolves with the parsed
 *   data; rejects when reading fails, parsing throws, or no valid points
 *   are found.
 */
const parseFile = (file) => {
  return new Promise((resolve, reject) => {
    const reader = new FileReader()

    reader.onload = (e) => {
      try {
        const text = e.target.result

        // Determine file format and parse accordingly
        const fileName = file.name.toLowerCase()
        let data = { x: [], y: [] }
        if (fileName.endsWith('.cif')) {
          // CIF format - look for loop_ structures
          data = parseCIF(text)
          // Fallback to simple parsing if CIF parsing didn't find data
          if (data.x.length === 0) {
            console.log('CIF loop parsing failed, falling back to simple parser')
            data = parseDIF(text)
          }
        } else {
          // DIF, XY, CSV, TXT - simple space/comma separated
          data = parseDIF(text)
        }

        if (data.x.length === 0 || data.y.length === 0) {
          reject(new Error('No valid data points found in file'))
          return
        }

        // Extract metadata (including wavelength) now that the file is accepted
        const metadata = extractMetadata(text)
        if (metadata.wavelength) {
          setDetectedWavelength(metadata.wavelength)
          setUserWavelength(metadata.wavelength)
          setWavelengthSource('detected')
        } else {
          setDetectedWavelength(null)
          setWavelengthSource('default')
        }

        console.log(`Parsed ${data.x.length} data points from ${fileName}`)
        resolve(data)
      } catch (error) {
        reject(error)
      }
    }

    reader.onerror = () => reject(new Error('Failed to read file'))
    reader.readAsText(file)
  })
}
// Upload and parse file
/**
 * Load a user-selected file into the provider state.
 *
 * On success the parsed data and file name are stored and any previous
 * analysis (results, status, logit drawer) is cleared. On failure the error
 * is logged and shown in an alert; existing state is left alone.
 *
 * @param {File} file - File selected by the user.
 * @returns {Promise<boolean>} true when the file was loaded, false otherwise.
 */
const handleFileUpload = async (file) => {
  try {
    const parsed = await parseFile(file)

    // New pattern in place
    setRawData(parsed)
    setFilename(file.name)

    // Anything derived from the previous pattern is now stale
    setModelResults(null)
    setAnalysisStatus('IDLE')
    setIsLogitDrawerOpen(false)

    return true
  } catch (uploadError) {
    console.error('Error uploading file:', uploadError)
    alert(`Error loading file: ${uploadError.message}`)
    return false
  }
}
// Send processed data to API for inference
/**
 * POST the processed pattern to /api/predict and store the response.
 *
 * Each call increments analysisCount and includes it, a timestamp and the
 * file name in the request body. Status moves IDLE -> PROCESSING -> COMPLETE,
 * or back to IDLE on failure (the error is logged and shown in an alert).
 * isLoading is cleared in all cases.
 *
 * @returns {Promise<void>}
 */
const runInference = useCallback(async () => {
  if (!processedData) {
    alert('No data to analyze')
    return
  }

  // Increment analysis counter - tracks button clicks
  const currentCount = analysisCount + 1
  setAnalysisCount(currentCount)
  setIsLoading(true)
  setAnalysisStatus('PROCESSING')

  try {
    const requestTimestamp = Date.now()
    const response = await fetch('/api/predict', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        // Anti-caching headers
        'Cache-Control': 'no-cache, no-store, must-revalidate',
        'Pragma': 'no-cache',
        'Expires': '0',
        // Request tracking
        'X-Request-ID': String(requestTimestamp),
        // Header values must be byte strings: a file name containing
        // characters above U+00FF makes fetch() throw before the request is
        // sent. Percent-encode it here; the exact name travels in the body.
        'X-Filename': encodeURIComponent(filename || 'unknown'),
      },
      // Explicitly disable caching for this request
      cache: 'no-store',
      body: JSON.stringify({
        x: processedData.x,
        y: processedData.y,
        // Include metadata to help track requests
        metadata: {
          timestamp: requestTimestamp,
          filename: filename,
          analysisCount: currentCount,
        }
      }),
    })

    if (!response.ok) {
      throw new Error(`API error: ${response.status}`)
    }

    const results = await response.json()
    setModelResults(results)
    setAnalysisStatus('COMPLETE')
  } catch (error) {
    console.error('Error running inference:', error)
    alert(`Inference failed: ${error.message}`)
    setAnalysisStatus('IDLE')
  } finally {
    setIsLoading(false)
  }
}, [processedData, analysisCount, filename])
// Load an example data file from the API
/**
 * Fetch an example pattern from /api/examples/<filename> and load it the same
 * way an uploaded file is loaded.
 *
 * Wavelength state is updated only after the example is known to contain
 * data, so a failed load leaves the wavelength of the currently loaded
 * pattern untouched.
 *
 * @param {string} filename - Name of the example file to request.
 * @returns {Promise<boolean>} true when the example was loaded; false when
 *   the request or parsing failed (the error is logged and shown in an alert).
 */
const loadExampleFile = useCallback(async (filename) => {
  try {
    const response = await fetch(`/api/examples/${encodeURIComponent(filename)}`)
    if (!response.ok) {
      throw new Error(`Failed to fetch example: ${response.status}`)
    }
    const text = await response.text()

    // Parse using the DIF parser (all examples are .dif)
    const data = parseDIF(text)
    if (data.x.length === 0 || data.y.length === 0) {
      throw new Error('No valid data points found in example file')
    }

    // Extract metadata (including wavelength) — same as normal file upload,
    // applied once the example has been accepted
    const metadata = extractMetadata(text)
    if (metadata.wavelength) {
      setDetectedWavelength(metadata.wavelength)
      setUserWavelength(metadata.wavelength)
      setWavelengthSource('detected')
    } else {
      setDetectedWavelength(null)
      setWavelengthSource('default')
    }

    setRawData(data)
    setFilename(filename)
    setModelResults(null)
    setAnalysisStatus('IDLE')
    setIsLogitDrawerOpen(false)
    return true
  } catch (error) {
    console.error('Error loading example file:', error)
    alert(`Error loading example: ${error.message}`)
    return false
  }
}, [])
// Reset application state
/**
 * Return the provider to its no-file-loaded state.
 *
 * Clears the loaded pattern, its results and UI state, and also the values
 * that were derived from that file: processing warnings and the detected
 * wavelength (the working wavelength goes back to the model default).
 * Processing preferences (baseline, scaling, interpolation) are kept.
 */
const handleReset = () => {
  setRawData(null)
  setFilename(null)
  setModelResults(null)
  setAnalysisStatus('IDLE')
  setIsLogitDrawerOpen(false)

  // With no data loaded the processing step returns early without touching
  // the warnings, so they would otherwise outlive the file they describe.
  setDataWarnings([])

  // Wavelength information belonged to the removed file.
  setDetectedWavelength(null)
  setUserWavelength(MODEL_WAVELENGTH)
  setWavelengthSource('default')
}
const value = {
rawData,
processedData,
modelResults,
isLoading,
filename,
analysisStatus,
detectedWavelength,
userWavelength,
setUserWavelength,
wavelengthSource,
dataWarnings,
baselineCorrection,
setBaselineCorrection,
interpolationEnabled,
setInterpolationEnabled,
scalingEnabled,
setScalingEnabled,
interpolationStrategy,
setInterpolationStrategy,
isLogitDrawerOpen,
setIsLogitDrawerOpen,
handleFileUpload,
loadExampleFile,
runInference,
handleReset,
MODEL_WAVELENGTH,
MODEL_MIN_2THETA,
MODEL_MAX_2THETA,
MODEL_INPUT_SIZE,
}
return <XRDContext.Provider value={value}>{children}</XRDContext.Provider>
}