anycalib-wasm / index.js
SebRincon's picture
Upload AnyCalib WASM demo (ONNX Runtime Web)
1313f05 verified
/**
* AnyCalib WASM — Camera calibration inference via ONNX Runtime Web.
*
* This module loads the AnyCalib ONNX model and runs inference in the browser
* using WebAssembly (WASM) or WebGPU backends.
*
* Usage:
* import { AnyCalibrator } from './index.js';
* const calibrator = new AnyCalibrator();
* await calibrator.init();
* const result = await calibrator.predict(imageElement);
*/
import * as ort from 'onnxruntime-web';
// Configure WASM paths. Pin the jsDelivr URL to the onnxruntime-web version
// that is actually loaded (ort.env.versions.web) so the fetched .wasm binaries
// match this JS bundle; an unversioned jsDelivr URL resolves to the newest
// published release, which can be incompatible. Fall back to the unversioned
// URL if the runtime does not report its version.
const ORT_WEB_VERSION = ort.env.versions && ort.env.versions.web;
ort.env.wasm.wasmPaths = ORT_WEB_VERSION
  ? `https://cdn.jsdelivr.net/npm/onnxruntime-web@${ORT_WEB_VERSION}/dist/`
  : 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/';
// Default model location (Hugging Face Hub); callers can override it via the
// `modelUrl` constructor option.
const MODEL_URL = 'https://huggingface.co/SebRincon/anycalib-onnx/resolve/main/model_int8.onnx';
// Default side length, in pixels, of the square image fed to the model;
// callers can override it via the `inputSize` constructor option.
const INPUT_SIZE = 518;
/**
 * Browser-side runner for the AnyCalib ONNX model.
 *
 * Wraps an onnxruntime-web InferenceSession and provides image preprocessing,
 * inference, and a simple distortion-visualization helper.
 */
export class AnyCalibrator {
  /**
   * @param {Object} [options]
   * @param {string} [options.modelUrl] - URL of the ONNX model (defaults to MODEL_URL).
   * @param {number} [options.inputSize] - Side length in pixels of the square
   *   model input (defaults to INPUT_SIZE).
   * @param {string} [options.executionProvider='wasm'] - onnxruntime-web
   *   execution provider name passed to InferenceSession.create.
   */
  constructor(options = {}) {
    // `||` is intentional: an empty URL or a zero input size is never valid,
    // so any falsy value falls back to the default.
    this.modelUrl = options.modelUrl || MODEL_URL;
    this.inputSize = options.inputSize || INPUT_SIZE;
    this.session = null;
    this.executionProvider = options.executionProvider || 'wasm';
    // Scratch canvas/context reused by preprocessImage(); created lazily so
    // constructing the object does not touch the DOM.
    this.preprocessCanvas = null;
    this.preprocessContext = null;
  }

  /**
   * Create the ONNX Runtime inference session for `this.modelUrl`.
   *
   * @returns {Promise<AnyCalibrator>} This instance, to allow chaining.
   */
  async init() {
    console.log(`[AnyCalib] Loading model from ${this.modelUrl}...`);
    console.log(`[AnyCalib] Using ${this.executionProvider} backend`);
    const startTime = performance.now();
    this.session = await ort.InferenceSession.create(this.modelUrl, {
      executionProviders: [this.executionProvider],
      graphOptimizationLevel: 'all',
    });
    const elapsed = ((performance.now() - startTime) / 1000).toFixed(1);
    console.log(`[AnyCalib] Model loaded in ${elapsed}s`);
    return this;
  }

  /**
   * Preprocess an image element, canvas, or video frame to a float32 tensor.
   *
   * The source is stretched (aspect ratio is not preserved) to
   * inputSize x inputSize, and the RGBA pixels are repacked into planar RGB
   * (CHW) values normalized to [0, 1]. The alpha channel is dropped.
   *
   * @param {CanvasImageSource} imageSource - Anything accepted by drawImage().
   * @returns {ort.Tensor} float32 tensor of shape [1, 3, inputSize, inputSize].
   */
  preprocessImage(imageSource) {
    const size = this.inputSize;
    const ctx = this.getPreprocessContext(size);
    // The canvas is reused across calls, so clear it first; otherwise
    // transparent pixels of the new source would be composited over the
    // previous frame instead of over transparent black.
    ctx.clearRect(0, 0, size, size);
    ctx.drawImage(imageSource, 0, 0, size, size);
    const { data, width, height } = ctx.getImageData(0, 0, size, size);
    const plane = width * height;
    // Convert RGBA HWC → RGB CHW float32 [0,1]
    const float32Data = new Float32Array(3 * plane);
    for (let i = 0; i < plane; i++) {
      float32Data[i] = data[i * 4] / 255.0; // R
      float32Data[plane + i] = data[i * 4 + 1] / 255.0; // G
      float32Data[2 * plane + i] = data[i * 4 + 2] / 255.0; // B
    }
    return new ort.Tensor('float32', float32Data, [1, 3, height, width]);
  }

  /**
   * Return the cached 2D context of the scratch canvas, creating the canvas
   * on first use and resizing it to `size` x `size` when needed.
   *
   * @param {number} size - Required canvas side length in pixels.
   * @returns {CanvasRenderingContext2D}
   */
  getPreprocessContext(size) {
    if (!this.preprocessContext) {
      const canvas = document.createElement('canvas');
      // Pixels are read back on every call via getImageData, which is the
      // use case the willReadFrequently hint is meant for.
      this.preprocessContext = canvas.getContext('2d', { willReadFrequently: true });
      this.preprocessCanvas = canvas;
    }
    const canvas = this.preprocessCanvas;
    if (canvas.width !== size || canvas.height !== size) {
      canvas.width = size;
      canvas.height = size;
    }
    return this.preprocessContext;
  }

  /**
   * Run inference on an image element, canvas, or video frame.
   *
   * @param {CanvasImageSource} imageSource
   * @returns {Promise<{rays: *, tangentCoords: *, elapsed: number}>} The
   *   session outputs named `rays` and `tangent_coords` (undefined if the
   *   model does not produce an output with that name), and `elapsed`, the
   *   time in milliseconds spent in session.run (preprocessing excluded).
   * @throws {Error} If init() has not completed yet.
   */
  async predict(imageSource) {
    if (!this.session) {
      throw new Error('Model not initialized. Call init() first.');
    }
    const inputTensor = this.preprocessImage(imageSource);
    const startTime = performance.now();
    const results = await this.session.run({ image: inputTensor });
    const elapsed = performance.now() - startTime;
    return {
      rays: results.rays,
      tangentCoords: results.tangent_coords,
      elapsed,
    };
  }

  /**
   * Compute a simple distortion heatmap from ray predictions.
   *
   * Assumes `rays.dims` is [batch, channels, height, width] with channels
   * 0, 1, 2 holding the x, y, z ray components (TODO: confirm against the
   * exported model). Only the first batch item is read. Per pixel the value
   * is sqrt(x^2 + y^2) / max(|z|, 1e-6), i.e. the tangent of the angle
   * between the ray and the z axis, scaled by 128 and capped at 255.
   *
   * @param {{dims: number[], data: ArrayLike<number>}} rays
   * @returns {{data: Uint8ClampedArray, width: number, height: number}}
   *   Row-major heatmap of length height * width plus its dimensions.
   */
  computeDistortionMap(rays) {
    const [, , height, width] = rays.dims;
    const { data } = rays;
    const plane = height * width;
    const heatmap = new Uint8ClampedArray(plane);
    for (let idx = 0; idx < plane; idx++) {
      const rx = data[idx]; // channel 0
      const ry = data[plane + idx]; // channel 1
      const rz = data[2 * plane + idx]; // channel 2
      // Deviation from the z axis; the 1e-6 floor avoids division by zero
      // for rays perpendicular to it.
      const deviation = Math.sqrt(rx * rx + ry * ry) / Math.max(Math.abs(rz), 1e-6);
      heatmap[idx] = Math.min(255, Math.floor(deviation * 128));
    }
    return { data: heatmap, width, height };
  }
}
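// Also expose the class as a browser global (window.AnyCalibrator) for code that does not import this module, e.g. plain <script> tags.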
if (typeof window !== 'undefined') {
  // Publish the class on the browser global object; skipped in non-browser runtimes.
  Object.assign(window, { AnyCalibrator });
}