// geoclip-api / server.js
// Uploaded by latterworks with huggingface_hub (commit dd6462a, verified).
import express from 'express';
import multer from 'multer';
import { AutoModel, AutoProcessor, RawImage, Tensor } from '@huggingface/transformers';
// --- Model identifiers -----------------------------------------------------
const MODEL_ID = 'Xenova/geoclip-large-patch14';
const PROCESSOR_ID = 'openai/clip-vit-large-patch14';

// List of gallery coordinates hosted in the model repository; fetched by init().
const GALLERY_URL = new URL(
  `${MODEL_ID}/resolve/main/gps_gallery/coordinates_100K.json`,
  'https://huggingface.co/',
).href;

// --- Inference tuning ------------------------------------------------------
const BATCH = 512; // gallery coordinates embedded per location-model call
const TOP_K = 5; // ranked predictions returned by /predict
const HEATMAP_K = 100; // highest-probability points returned for the heatmap

// Multiplier applied to each image/gallery dot product before the softmax.
// NOTE(review): the exponent appears to be the model's learned log logit
// scale; confirm it matches the checkpoint in use.
const EXP_LOGIT_SCALE = Math.exp(3.681034803390503);

// --- Upload filtering ------------------------------------------------------
// MIME types accepted by the multer fileFilter.
const ALLOWED_MIME = new Set(['image/jpeg', 'image/jpg', 'image/png', 'image/webp']);
const app = express();

/**
 * Multipart upload handler: keeps the file in memory, caps it at 20 MiB, and
 * rejects MIME types outside ALLOWED_MIME with an error tagged `status: 415`.
 *
 * NOTE(review): `file.mimetype` comes from the client's multipart headers, so
 * this is a convenience filter, not a content check.
 */
const upload = multer({
  storage: multer.memoryStorage(),
  limits: { fileSize: 20 * 1024 * 1024 },
  fileFilter: (_req, file, cb) => {
    if (ALLOWED_MIME.has(file.mimetype)) {
      cb(null, true);
      return;
    }
    const err = new Error('unsupported file type');
    err.status = 415; // turned into the HTTP status by the /predict upload callback
    cb(err);
  },
});

// Model handles, gallery coordinates and normalized gallery embeddings;
// assigned by init(), which is awaited before app.listen().
let visionModel;
let locationModel;
let processor;
let gpsData;
let galleryEmbeds;
let embedDim;
/**
 * L2-normalize every row of a flat, row-major matrix.
 *
 * @param {ArrayLike<number>} arr - Flat values; length is expected to be a
 *   multiple of `dim`.
 * @param {number} dim - Row width (number of columns).
 * @returns {Float32Array} A new array of the same length in which each row has
 *   unit Euclidean length. All-zero rows are returned as zeros (not NaN).
 */
function normalize(arr, dim) {
  const result = new Float32Array(arr.length);
  const rows = arr.length / dim;
  for (let row = 0; row < rows; row++) {
    const start = row * dim;
    let sumSq = 0;
    for (let j = 0; j < dim; j++) sumSq += arr[start + j] ** 2;
    // A zero row has no direction; dividing by 0 would produce NaN and poison
    // every downstream dot product and softmax, so leave it as zeros.
    if (sumSq === 0) continue;
    const norm = Math.sqrt(sumSq);
    for (let j = 0; j < dim; j++) result[start + j] = arr[start + j] / norm;
  }
  return result;
}
/**
 * Score one query embedding against every row of a flat gallery matrix.
 *
 * @param {ArrayLike<number>} imgEmbed - Query vector of length `dim`.
 * @param {ArrayLike<number>} gallery - Row-major matrix holding `galleryLen`
 *   rows of `dim` values each.
 * @param {number} galleryLen - Number of gallery rows to score.
 * @param {number} dim - Embedding width.
 * @param {number} [scale=EXP_LOGIT_SCALE] - Multiplier applied to each dot
 *   product (logit scale). Defaults to the module constant, so existing
 *   callers are unchanged.
 * @returns {Float32Array} `galleryLen` scaled dot products; with unit-length
 *   inputs these are scaled cosine similarities.
 */
function scoreVsGallery(imgEmbed, gallery, galleryLen, dim, scale = EXP_LOGIT_SCALE) {
  const scores = new Float32Array(galleryLen);
  for (let i = 0; i < galleryLen; i++) {
    const off = i * dim;
    let dot = 0;
    for (let j = 0; j < dim; j++) dot += imgEmbed[j] * gallery[off + j];
    scores[i] = scale * dot;
  }
  return scores;
}
/**
 * Numerically stable softmax: subtracts the largest score before
 * exponentiating so large logits cannot overflow `Math.exp`.
 *
 * @param {ArrayLike<number>} scores - Raw logits (array or typed array).
 * @returns {Float32Array} Probabilities of the same length, summing to ~1.
 */
function softmax(scores) {
  const peak = scores.reduce((best, s) => (s > best ? s : best), -Infinity);
  const exps = new Float32Array(scores.length);
  let total = 0;
  scores.forEach((s, idx) => {
    exps[idx] = Math.exp(s - peak);
    // Accumulate the stored float32 value, not the float64 intermediate.
    total += exps[idx];
  });
  return exps.map((e) => e / total);
}
/**
 * Load the vision model, location model and processor; download the GPS
 * gallery; and precompute L2-normalized location embeddings for every gallery
 * coordinate.
 *
 * Assigns the module-level `visionModel`, `locationModel`, `processor`,
 * `gpsData`, `galleryEmbeds` and `embedDim` bindings.
 *
 * @returns {Promise<void>}
 * @throws {Error} if the gallery download returns a non-2xx status, the gallery
 *   is not a non-empty array, or the location model returns embeddings whose
 *   width is not a positive integer or differs between batches.
 */
async function init() {
  console.log('loading models...');
  [visionModel, locationModel, processor] = await Promise.all([
    AutoModel.from_pretrained(MODEL_ID, { model_file_name: 'vision_model_quantized' }),
    AutoModel.from_pretrained(MODEL_ID, { model_file_name: 'location_model', quantized: false }),
    AutoProcessor.from_pretrained(PROCESSOR_ID),
  ]);

  console.log('fetching gps gallery...');
  const res = await fetch(GALLERY_URL);
  // Without this check an HTTP error page surfaces as an opaque JSON parse error.
  if (!res.ok) {
    throw new Error(`failed to fetch gps gallery: HTTP ${res.status} ${res.statusText}`);
  }
  gpsData = await res.json();
  if (!Array.isArray(gpsData) || gpsData.length === 0) {
    throw new Error('gps gallery is empty or malformed');
  }

  console.log(`computing ${gpsData.length} gallery embeddings...`);
  let embeds = null;
  for (let i = 0; i < gpsData.length; i += BATCH) {
    const chunk = gpsData.slice(i, i + BATCH);
    // Each chunk entry is flattened into consecutive pairs -> shape [n, 2].
    const { location_embeds } = await locationModel({
      location: new Tensor('float32', chunk.flat(), [chunk.length, 2]),
    });
    const dim = location_embeds.data.length / chunk.length;
    if (embeds === null) {
      if (!Number.isInteger(dim) || dim <= 0) {
        throw new Error(`unexpected location embedding width: ${dim}`);
      }
      embedDim = dim;
      // The width is known after the first batch; allocate the whole gallery once.
      embeds = new Float32Array(gpsData.length * embedDim);
    } else if (dim !== embedDim) {
      throw new Error(`inconsistent location embedding width: expected ${embedDim}, got ${dim}`);
    }
    embeds.set(location_embeds.data, i * embedDim);
  }
  galleryEmbeds = normalize(embeds, embedDim);
  console.log(`gallery ready shape=[${gpsData.length}, ${embedDim}]`);
}
// Multer middleware for the single multipart field "file", built once at
// startup rather than on every request.
const uploadSingleFile = upload.single('file');

/**
 * POST /predict: accepts one image in multipart field "file", embeds it with
 * the vision model, scores it against every gallery embedding, and returns the
 * TOP_K most probable coordinates plus the HEATMAP_K highest-probability
 * points.
 *
 * Status codes:
 * - 400: no file, or another upload error without its own status.
 * - 413: file exceeds the upload size limit.
 * - 415: disallowed MIME type.
 * - 500: decoding or inference throws.
 */
app.post('/predict', (req, res, next) => uploadSingleFile(req, res, (err) => {
  if (err) {
    // Multer reports oversized uploads through err.code and sets no status.
    // The fileFilter tags type rejections with status 415.
    const status = err.code === 'LIMIT_FILE_SIZE' ? 413 : (err.status ?? 400);
    return res.status(status).json({ error: err.message });
  }
  return next();
}), async (req, res) => {
  if (!req.file) return res.status(400).json({ error: 'no file uploaded' });
  try {
    const t0 = Date.now();
    const image = await RawImage.fromBlob(new Blob([req.file.buffer], { type: req.file.mimetype }));
    const inputs = await processor(image);
    const { image_embeds } = await visionModel(inputs);
    const normImg = normalize(new Float32Array(image_embeds.data), embedDim);
    const scores = scoreVsGallery(normImg, galleryEmbeds, gpsData.length, embedDim);
    const probs = softmax(scores);
    // Rank gallery indices by descending probability. Sorting plain indices
    // avoids allocating one object per gallery entry on every request; the
    // sort is stable, so ties keep ascending index order.
    const order = Array.from(probs.keys()).sort((a, b) => probs[b] - probs[a]);
    const toPoint = (i) => {
      const [lat, lon] = gpsData[i];
      return { lat, lon, prob: probs[i] };
    };
    res.json({
      predictions: order.slice(0, TOP_K).map((i, rank) => ({ rank: rank + 1, ...toPoint(i) })),
      heatmap: order.slice(0, HEATMAP_K).map((i) => toPoint(i)),
      inference_ms: Date.now() - t0,
    });
  } catch (err) {
    console.error(err);
    res.status(500).json({ error: err.message });
  }
});
// Single-page client served at GET /. It uploads the chosen image to /predict
// and renders the returned predictions as a ranked list plus a Leaflet heatmap.
// The page is one template literal: backticks and `${` that must reach the
// browser are escaped, and the closing script tag is written as `<\/script>`.
const HTML = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GeoCLIP</title>
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css">
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"><\/script>
<script src="https://unpkg.com/leaflet.heat@0.2.0/dist/leaflet-heat.js"><\/script>
<style>
  *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
  body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; background: #0f0f0f; color: #e8e8e8; min-height: 100vh; display: flex; flex-direction: column; align-items: center; padding: 48px 16px; }
  header { text-align: center; margin-bottom: 40px; }
  header h1 { font-size: 1.6rem; font-weight: 600; letter-spacing: -0.02em; color: #fff; }
  header p { margin-top: 6px; font-size: 0.85rem; color: #666; }
  .card { background: #181818; border: 1px solid #272727; border-radius: 12px; width: 100%; max-width: 520px; overflow: hidden; }
  .drop-zone { padding: 40px 24px; display: flex; flex-direction: column; align-items: center; gap: 12px; cursor: pointer; border-bottom: 1px solid #272727; transition: background 0.15s; position: relative; }
  .drop-zone:hover, .drop-zone.drag-over { background: #1f1f1f; }
  .drop-zone input[type=file] { position: absolute; inset: 0; opacity: 0; cursor: pointer; }
  .drop-icon { width: 40px; height: 40px; border-radius: 10px; background: #242424; border: 1px solid #333; display: flex; align-items: center; justify-content: center; font-size: 18px; }
  .drop-zone span { font-size: 0.85rem; color: #888; }
  .drop-zone span b { color: #ccc; font-weight: 500; }
  #preview-wrap { display: none; padding: 16px; border-bottom: 1px solid #272727; }
  #preview-wrap img { width: 100%; border-radius: 8px; max-height: 280px; object-fit: cover; }
  .actions { padding: 16px; border-bottom: 1px solid #272727; display: none; }
  button { width: 100%; padding: 10px; background: #fff; color: #000; border: none; border-radius: 8px; font-size: 0.9rem; font-weight: 500; cursor: pointer; transition: opacity 0.15s; }
  button:hover { opacity: 0.85; }
  button:disabled { opacity: 0.4; cursor: not-allowed; }
  #results { display: none; }
  .results-header { padding: 14px 16px 8px; font-size: 0.7rem; font-weight: 600; letter-spacing: 0.08em; text-transform: uppercase; color: #555; display: flex; justify-content: space-between; }
  .prediction { padding: 12px 16px; display: flex; flex-direction: column; gap: 6px; border-top: 1px solid #1f1f1f; }
  .prediction:first-of-type { border-top: none; }
  .pred-row { display: flex; align-items: center; justify-content: space-between; gap: 12px; }
  .rank { font-size: 0.72rem; font-weight: 600; color: #444; width: 20px; flex-shrink: 0; }
  .coords { font-size: 0.88rem; font-variant-numeric: tabular-nums; color: #ddd; flex: 1; }
  .prob-label { font-size: 0.78rem; color: #666; font-variant-numeric: tabular-nums; flex-shrink: 0; }
  .bar-wrap { height: 3px; background: #222; border-radius: 2px; overflow: hidden; margin-left: 32px; }
  .bar { height: 100%; background: #fff; border-radius: 2px; transition: width 0.4s ease; }
  .prediction:nth-child(1) .bar { background: #fff; }
  .prediction:nth-child(2) .bar { background: #aaa; }
  .prediction:nth-child(3) .bar { background: #777; }
  .prediction:nth-child(4) .bar { background: #555; }
  .prediction:nth-child(5) .bar { background: #3a3a3a; }
  .meta { padding: 10px 16px; font-size: 0.72rem; color: #444; border-top: 1px solid #1f1f1f; text-align: right; }
  #map-wrap { border-bottom: 1px solid #272727; }
  #map { height: 260px; background: #111; }
  .leaflet-control-attribution { background: rgba(0,0,0,0.5) !important; color: #555 !important; font-size: 0.6rem !important; }
  .leaflet-control-attribution a { color: #666 !important; }
  #status { font-size: 0.8rem; color: #666; margin-top: 20px; min-height: 20px; }
</style>
</head>
<body>
<header>
  <h1>GeoCLIP</h1>
  <p>Upload an image to predict its location</p>
</header>
<div class="card">
  <div class="drop-zone" id="drop-zone">
    <input type="file" id="file-input" accept="image/jpeg,image/png,image/webp">
    <div class="drop-icon">&#127757;</div>
    <span><b>Click to upload</b> or drag and drop</span>
    <span>JPG, PNG, WEBP</span>
  </div>
  <div id="preview-wrap"><img id="preview" src="" alt="preview"></div>
  <div class="actions" id="actions"><button id="predict-btn">Predict location</button></div>
  <div id="results">
    <div class="results-header"><span>Predictions</span><span>Top 5</span></div>
    <div id="map-wrap"><div id="map"></div></div>
    <div id="predictions-list"></div>
    <div class="meta" id="meta"></div>
  </div>
</div>
<div id="status"></div>
<script>
  const dropZone = document.getElementById('drop-zone');
  const fileInput = document.getElementById('file-input');
  const previewWrap = document.getElementById('preview-wrap');
  const preview = document.getElementById('preview');
  const actions = document.getElementById('actions');
  const predictBtn = document.getElementById('predict-btn');
  const results = document.getElementById('results');
  const predList = document.getElementById('predictions-list');
  const meta = document.getElementById('meta');
  const statusEl = document.getElementById('status');
  // Same types the server's upload filter accepts; reject others before uploading.
  const ALLOWED_TYPES = ['image/jpeg', 'image/jpg', 'image/png', 'image/webp'];
  let selectedFile = null;
  let previewUrl = null;
  let leafletMap = null;
  let heatLayer = null;
  let pinMarkers = [];

  function setFile(file) {
    if (!file) return;
    if (!ALLOWED_TYPES.includes(file.type)) {
      statusEl.textContent = 'Unsupported file type. Use JPG, PNG or WEBP.';
      return;
    }
    selectedFile = file;
    // Release the previous preview's blob URL so repeated selections do not leak memory.
    if (previewUrl) URL.revokeObjectURL(previewUrl);
    previewUrl = URL.createObjectURL(file);
    preview.src = previewUrl;
    previewWrap.style.display = 'block';
    actions.style.display = 'block';
    results.style.display = 'none';
    predList.innerHTML = '';
    statusEl.textContent = '';
  }

  fileInput.addEventListener('change', e => setFile(e.target.files[0]));
  dropZone.addEventListener('dragover', e => { e.preventDefault(); dropZone.classList.add('drag-over'); });
  dropZone.addEventListener('dragleave', () => dropZone.classList.remove('drag-over'));
  dropZone.addEventListener('drop', e => { e.preventDefault(); dropZone.classList.remove('drag-over'); setFile(e.dataTransfer.files[0]); });

  predictBtn.addEventListener('click', async () => {
    if (!selectedFile) return;
    predictBtn.disabled = true;
    predictBtn.textContent = 'Predicting…';
    statusEl.textContent = '';
    results.style.display = 'none';
    const form = new FormData();
    form.append('file', selectedFile);
    try {
      const res = await fetch('/predict', { method: 'POST', body: form });
      if (!res.ok) {
        // The API replies with { error } JSON; fall back to the status code otherwise.
        const body = await res.json().catch(() => null);
        throw new Error(body?.error ?? 'server error ' + res.status);
      }
      const data = await res.json();
      if (!data.predictions?.length) throw new Error('no predictions returned');
      const maxProb = data.predictions[0].prob;
      predList.innerHTML = data.predictions.map(p => \`
        <div class="prediction">
          <div class="pred-row">
            <span class="rank">\${p.rank}</span>
            <span class="coords">\${p.lat.toFixed(4)}, \${p.lon.toFixed(4)}</span>
            <span class="prob-label">\${(p.prob * 100).toFixed(2)}%</span>
          </div>
          <div class="bar-wrap"><div class="bar" style="width:\${(p.prob / maxProb * 100).toFixed(1)}%"></div></div>
        </div>
      \`).join('');
      meta.textContent = data.inference_ms + 'ms';
      results.style.display = 'block';
      const heatMax = data.heatmap[0].prob;
      const heatPoints = data.heatmap.map(p => [p.lat, p.lon, p.prob / heatMax]);
      // Defer until the results panel is visible so Leaflet can measure the map container.
      setTimeout(() => {
        if (!leafletMap) {
          leafletMap = L.map('map', { zoomControl: true, attributionControl: true });
          L.tileLayer('https://{s}.basemaps.cartocdn.com/dark_all/{z}/{x}/{y}{r}.png', {
            maxZoom: 19,
            attribution: '© <a href="https://carto.com/">CARTO</a>',
          }).addTo(leafletMap);
        }
        leafletMap.invalidateSize();
        if (heatLayer) leafletMap.removeLayer(heatLayer);
        heatLayer = L.heatLayer(heatPoints, { radius: 28, blur: 20, maxZoom: 17, max: 1.0 }).addTo(leafletMap);
        pinMarkers.forEach(m => m.remove());
        pinMarkers = data.predictions.map((p, i) => L.circleMarker([p.lat, p.lon], {
          radius: i === 0 ? 9 : 5,
          fillColor: i === 0 ? '#ffffff' : '#888888',
          color: i === 0 ? '#cccccc' : '#555555',
          weight: i === 0 ? 2 : 1,
          fillOpacity: i === 0 ? 1 : 0.65,
        }).bindTooltip(\`#\${p.rank} \${p.lat.toFixed(4)}, \${p.lon.toFixed(4)}\`, { sticky: false, opacity: 0.9 }).addTo(leafletMap));
        leafletMap.fitBounds(L.latLngBounds(heatPoints.map(p => [p[0], p[1]])).pad(0.25));
      }, 0);
    } catch (err) {
      statusEl.textContent = err.message;
    } finally {
      predictBtn.disabled = false;
      predictBtn.textContent = 'Predict location';
    }
  });
<\/script>
</body>
</html>`;
// Liveness probe; reports how many gallery coordinates are loaded.
app.get('/health', (_req, res) => res.json({ status: 'ok', gallery_size: gpsData?.length ?? 0 }));
app.get('/', (_req, res) => res.type('html').send(HTML));

// Load models and build the gallery before accepting traffic, so /predict
// never runs against uninitialized state. A failure here rejects the
// top-level await and stops the process.
await init();

// 7860 remains the default; a numeric PORT environment variable overrides it.
const PORT = Number(process.env.PORT) || 7860;
app.listen(PORT, () => console.log(`listening on :${PORT}`));