import express from 'express'; import multer from 'multer'; import { AutoModel, AutoProcessor, RawImage, Tensor } from '@huggingface/transformers'; const MODEL_ID = 'Xenova/geoclip-large-patch14'; const PROCESSOR_ID = 'openai/clip-vit-large-patch14'; const GALLERY_URL = `https://huggingface.co/${MODEL_ID}/resolve/main/gps_gallery/coordinates_100K.json`; const BATCH = 512; const TOP_K = 5; const HEATMAP_K = 100; const EXP_LOGIT_SCALE = Math.exp(3.681034803390503); const ALLOWED_MIME = new Set(['image/jpeg', 'image/jpg', 'image/png', 'image/webp']); const app = express(); const upload = multer({ storage: multer.memoryStorage(), limits: { fileSize: 20 * 1024 * 1024 }, fileFilter: (_req, file, cb) => { ALLOWED_MIME.has(file.mimetype) ? cb(null, true) : cb(Object.assign(new Error('unsupported file type'), { status: 415 })); }, }); let visionModel, locationModel, processor, gpsData, galleryEmbeds, embedDim; function normalize(arr, dim) { const result = new Float32Array(arr.length); const rows = arr.length / dim; for (let i = 0; i < rows; i++) { let norm = 0; for (let j = 0; j < dim; j++) norm += arr[i * dim + j] ** 2; norm = Math.sqrt(norm); for (let j = 0; j < dim; j++) result[i * dim + j] = arr[i * dim + j] / norm; } return result; } function scoreVsGallery(imgEmbed, gallery, galleryLen, dim) { const scores = new Float32Array(galleryLen); for (let i = 0; i < galleryLen; i++) { let dot = 0; const off = i * dim; for (let j = 0; j < dim; j++) dot += imgEmbed[j] * gallery[off + j]; scores[i] = EXP_LOGIT_SCALE * dot; } return scores; } function softmax(scores) { let max = -Infinity; for (const s of scores) if (s > max) max = s; let sum = 0; const probs = new Float32Array(scores.length); for (let i = 0; i < scores.length; i++) { probs[i] = Math.exp(scores[i] - max); sum += probs[i]; } for (let i = 0; i < probs.length; i++) probs[i] /= sum; return probs; } async function init() { console.log('loading models...'); [visionModel, locationModel, processor] = await Promise.all([ AutoModel.from_pretrained(MODEL_ID, { model_file_name: 'vision_model_quantized' }), AutoModel.from_pretrained(MODEL_ID, { model_file_name: 'location_model', quantized: false }), AutoProcessor.from_pretrained(PROCESSOR_ID), ]); console.log('fetching gps gallery...'); const res = await fetch(GALLERY_URL); gpsData = await res.json(); console.log(`computing ${gpsData.length} gallery embeddings...`); const chunks = []; let totalDim = null; for (let i = 0; i < gpsData.length; i += BATCH) { const chunk = gpsData.slice(i, i + BATCH); const { location_embeds } = await locationModel({ location: new Tensor('float32', chunk.flat(), [chunk.length, 2]), }); const data = new Float32Array(location_embeds.data); const dim = data.length / chunk.length; if (totalDim === null) totalDim = dim; chunks.push(data); } embedDim = totalDim; galleryEmbeds = new Float32Array(gpsData.length * embedDim); let offset = 0; for (const c of chunks) { galleryEmbeds.set(c, offset); offset += c.length; } galleryEmbeds = normalize(galleryEmbeds, embedDim); console.log(`gallery ready shape=[${gpsData.length}, ${embedDim}]`); } app.post('/predict', (req, res, next) => upload.single('file')(req, res, err => { if (err) return res.status(err.status ?? 400).json({ error: err.message }); next(); }), async (req, res) => { if (!req.file) return res.status(400).json({ error: 'no file uploaded' }); try { const t0 = Date.now(); const image = await RawImage.fromBlob(new Blob([req.file.buffer], { type: req.file.mimetype })); const inputs = await processor(image); const { image_embeds } = await visionModel(inputs); const imgArr = new Float32Array(image_embeds.data); const normImg = normalize(imgArr, embedDim); const scores = scoreVsGallery(normImg, galleryEmbeds, gpsData.length, embedDim); const probs = softmax(scores); const sorted = Array.from(probs, (p, i) => ({ p, i })).sort((a, b) => b.p - a.p); res.json({ predictions: sorted.slice(0, TOP_K).map(({ p, i }, rank) => ({ rank: rank + 1, lat: gpsData[i][0], lon: gpsData[i][1], prob: p, })), heatmap: sorted.slice(0, HEATMAP_K).map(({ p, i }) => ({ lat: gpsData[i][0], lon: gpsData[i][1], prob: p, })), inference_ms: Date.now() - t0, }); } catch (err) { console.error(err); res.status(500).json({ error: err.message }); } }); const HTML = `