Spaces:
Running
Running
| import express from 'express'; | |
| import multer from 'multer'; | |
| import { AutoModel, AutoProcessor, RawImage, Tensor } from '@huggingface/transformers'; | |
// Hugging Face repositories for the GeoCLIP weights and for the image preprocessor.
const MODEL_ID = 'Xenova/geoclip-large-patch14';
const PROCESSOR_ID = 'openai/clip-vit-large-patch14';
// JSON file downloaded at startup; init() feeds its entries to the location model as 2-number rows.
const GALLERY_URL = `https://huggingface.co/${MODEL_ID}/resolve/main/gps_gallery/coordinates_100K.json`;

// Number of gallery coordinates passed to the location model per call.
const BATCH = 512;
// Number of ranked predictions returned by /predict.
const TOP_K = 5;
// Number of highest-probability gallery points returned by /predict for the heatmap.
const HEATMAP_K = 100;
// Multiplier applied to every image/gallery dot product before the softmax.
const EXP_LOGIT_SCALE = Math.exp(3.681034803390503);

// Upload MIME types the server accepts.
const ALLOWED_MIME = new Set(['image/jpeg', 'image/jpg', 'image/png', 'image/webp']);

const app = express();

/**
 * Multer file filter: accepts a file whose MIME type is in ALLOWED_MIME and
 * otherwise fails the upload with an Error carrying `status: 415`.
 */
function filterByMimeType(_req, file, cb) {
  if (ALLOWED_MIME.has(file.mimetype)) {
    cb(null, true);
    return;
  }
  const rejection = new Error('unsupported file type');
  rejection.status = 415;
  cb(rejection);
}

// Uploads are buffered in memory and capped at 20 MiB.
const upload = multer({
  storage: multer.memoryStorage(),
  limits: { fileSize: 20 * 1024 * 1024 },
  fileFilter: filterByMimeType,
});

// Shared state assigned by init(), which is awaited before the server starts listening.
let visionModel;
let locationModel;
let processor;
let gpsData;
let galleryEmbeds;
let embedDim;
/**
 * L2-normalizes every row of a flat, row-major matrix.
 *
 * Rows whose norm is exactly zero are returned as all zeros instead of NaN,
 * so a degenerate embedding cannot poison later dot products.
 *
 * @param {ArrayLike<number>} arr - Flattened matrix; its length is expected to be a multiple of `dim`.
 * @param {number} dim - Number of elements per row; must be a positive integer.
 * @returns {Float32Array} A new array of the same length as `arr`; `arr` is not modified.
 * @throws {RangeError} If `dim` is not a positive integer (a zero or missing `dim` would otherwise never terminate).
 */
function normalize(arr, dim) {
  if (!Number.isInteger(dim) || dim <= 0) {
    throw new RangeError(`normalize: dim must be a positive integer, got ${dim}`);
  }
  const result = new Float32Array(arr.length);
  const rows = arr.length / dim;
  for (let i = 0; i < rows; i++) {
    const base = i * dim;
    let sumSquares = 0;
    for (let j = 0; j < dim; j++) sumSquares += arr[base + j] ** 2;
    const norm = Math.sqrt(sumSquares);
    // A zero row has no direction; keep the zeros already in `result` rather than dividing by zero.
    if (norm === 0) continue;
    for (let j = 0; j < dim; j++) result[base + j] = arr[base + j] / norm;
  }
  return result;
}
/**
 * Computes one scaled similarity score per gallery row.
 *
 * @param {ArrayLike<number>} imgEmbed - Query vector with at least `dim` elements.
 * @param {ArrayLike<number>} gallery - Flat row-major matrix of `galleryLen` rows by `dim` columns.
 * @param {number} galleryLen - Number of gallery rows to score.
 * @param {number} dim - Elements per row.
 * @returns {Float32Array} `galleryLen` values, each EXP_LOGIT_SCALE times the dot product of the query with that row.
 */
function scoreVsGallery(imgEmbed, gallery, galleryLen, dim) {
  const logits = new Float32Array(galleryLen);
  for (let row = 0; row < galleryLen; row += 1) {
    const rowStart = row * dim;
    let acc = 0;
    for (let col = 0; col < dim; col += 1) {
      acc += imgEmbed[col] * gallery[rowStart + col];
    }
    logits[row] = EXP_LOGIT_SCALE * acc;
  }
  return logits;
}
/**
 * Converts scores into probabilities with a softmax.
 *
 * The largest score is subtracted before exponentiating so large inputs do not overflow.
 *
 * @param {ArrayLike<number>} scores - Raw scores.
 * @returns {Float32Array} Probabilities in the same order as `scores`.
 */
function softmax(scores) {
  let peak = -Infinity;
  for (const value of scores) {
    if (value > peak) {
      peak = value;
    }
  }

  const probs = new Float32Array(scores.length);
  let total = 0;
  for (let k = 0; k < scores.length; k += 1) {
    probs[k] = Math.exp(scores[k] - peak);
    // Accumulate the stored (float32) value so the total matches what is later divided.
    total += probs[k];
  }

  for (let k = 0; k < probs.length; k += 1) {
    probs[k] = probs[k] / total;
  }
  return probs;
}
/**
 * Loads the models and processor, downloads the GPS gallery and precomputes
 * the normalized gallery embeddings.
 *
 * Assigns the module-level `visionModel`, `locationModel`, `processor`,
 * `gpsData`, `embedDim` and `galleryEmbeds`.
 *
 * @returns {Promise<void>}
 * @throws {Error} If the gallery download returns a non-2xx status, or the
 *   downloaded gallery is not a non-empty array.
 */
async function init() {
  console.log('loading models...');
  [visionModel, locationModel, processor] = await Promise.all([
    AutoModel.from_pretrained(MODEL_ID, { model_file_name: 'vision_model_quantized' }),
    AutoModel.from_pretrained(MODEL_ID, { model_file_name: 'location_model', quantized: false }),
    AutoProcessor.from_pretrained(PROCESSOR_ID),
  ]);

  console.log('fetching gps gallery...');
  const res = await fetch(GALLERY_URL);
  // Without this check an HTTP error page would surface later as an opaque JSON parse failure.
  if (!res.ok) {
    throw new Error(`failed to fetch gps gallery: HTTP ${res.status} ${res.statusText}`);
  }
  const gallery = await res.json();
  // An empty gallery would leave embedDim unset and make every later prediction unusable.
  if (!Array.isArray(gallery) || gallery.length === 0) {
    throw new Error('gps gallery must be a non-empty array of coordinates');
  }
  gpsData = gallery;

  console.log(`computing ${gpsData.length} gallery embeddings...`);
  const chunks = [];
  let totalDim = null;
  for (let i = 0; i < gpsData.length; i += BATCH) {
    const chunk = gpsData.slice(i, i + BATCH);
    const { location_embeds } = await locationModel({
      location: new Tensor('float32', chunk.flat(), [chunk.length, 2]),
    });
    // Copy the output so it stays valid independently of the model's own buffer.
    const data = new Float32Array(location_embeds.data);
    // The embedding width is taken from the first batch.
    if (totalDim === null) totalDim = data.length / chunk.length;
    chunks.push(data);
  }

  embedDim = totalDim;
  const combined = new Float32Array(gpsData.length * embedDim);
  let offset = 0;
  for (const c of chunks) {
    combined.set(c, offset);
    offset += c.length;
  }
  galleryEmbeds = normalize(combined, embedDim);
  console.log(`gallery ready shape=[${gpsData.length}, ${embedDim}]`);
}
/**
 * Middleware that runs multer for the single `file` field and answers upload
 * failures itself with a JSON error (the error's `status`, or 400 when it has none).
 */
function receiveUpload(req, res, next) {
  upload.single('file')(req, res, (err) => {
    if (err) {
      res.status(err.status ?? 400).json({ error: err.message });
      return;
    }
    next();
  });
}

/**
 * Embeds the uploaded image, scores it against the gallery and responds with
 * the TOP_K ranked predictions, the HEATMAP_K most probable points and the
 * elapsed time in milliseconds.
 */
async function handlePredict(req, res) {
  if (!req.file) {
    res.status(400).json({ error: 'no file uploaded' });
    return;
  }

  try {
    const startedAt = Date.now();

    const blob = new Blob([req.file.buffer], { type: req.file.mimetype });
    const image = await RawImage.fromBlob(blob);
    const inputs = await processor(image);
    const { image_embeds } = await visionModel(inputs);

    const query = normalize(new Float32Array(image_embeds.data), embedDim);
    const logits = scoreVsGallery(query, galleryEmbeds, gpsData.length, embedDim);
    const probs = softmax(logits);

    // Pair every probability with its gallery index, most probable first.
    const ranked = Array.from(probs, (p, i) => ({ p, i }));
    ranked.sort((a, b) => b.p - a.p);

    const toPoint = ({ p, i }) => ({
      lat: gpsData[i][0],
      lon: gpsData[i][1],
      prob: p,
    });

    res.json({
      predictions: ranked.slice(0, TOP_K).map((entry, position) => ({
        rank: position + 1,
        ...toPoint(entry),
      })),
      heatmap: ranked.slice(0, HEATMAP_K).map((entry) => toPoint(entry)),
      inference_ms: Date.now() - startedAt,
    });
  } catch (err) {
    console.error(err);
    res.status(500).json({ error: err.message });
  }
}

app.post('/predict', receiveUpload, handlePredict);
| const HTML = `<!DOCTYPE html> | |
| <html lang="en"> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>GeoCLIP</title> | |
| <link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css"> | |
| <script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"><\/script> | |
| <script src="https://unpkg.com/leaflet.heat/dist/leaflet-heat.js"><\/script> | |
| <style> | |
| *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } | |
| body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; background: #0f0f0f; color: #e8e8e8; min-height: 100vh; display: flex; flex-direction: column; align-items: center; padding: 48px 16px; } | |
| header { text-align: center; margin-bottom: 40px; } | |
| header h1 { font-size: 1.6rem; font-weight: 600; letter-spacing: -0.02em; color: #fff; } | |
| header p { margin-top: 6px; font-size: 0.85rem; color: #666; } | |
| .card { background: #181818; border: 1px solid #272727; border-radius: 12px; width: 100%; max-width: 520px; overflow: hidden; } | |
| .drop-zone { padding: 40px 24px; display: flex; flex-direction: column; align-items: center; gap: 12px; cursor: pointer; border-bottom: 1px solid #272727; transition: background 0.15s; position: relative; } | |
| .drop-zone:hover, .drop-zone.drag-over { background: #1f1f1f; } | |
| .drop-zone input[type=file] { position: absolute; inset: 0; opacity: 0; cursor: pointer; } | |
| .drop-icon { width: 40px; height: 40px; border-radius: 10px; background: #242424; border: 1px solid #333; display: flex; align-items: center; justify-content: center; font-size: 18px; } | |
| .drop-zone span { font-size: 0.85rem; color: #888; } | |
| .drop-zone span b { color: #ccc; font-weight: 500; } | |
| #preview-wrap { display: none; padding: 16px; border-bottom: 1px solid #272727; } | |
| #preview-wrap img { width: 100%; border-radius: 8px; max-height: 280px; object-fit: cover; } | |
| .actions { padding: 16px; border-bottom: 1px solid #272727; display: none; } | |
| button { width: 100%; padding: 10px; background: #fff; color: #000; border: none; border-radius: 8px; font-size: 0.9rem; font-weight: 500; cursor: pointer; transition: opacity 0.15s; } | |
| button:hover { opacity: 0.85; } | |
| button:disabled { opacity: 0.4; cursor: not-allowed; } | |
| #results { display: none; } | |
| .results-header { padding: 14px 16px 8px; font-size: 0.7rem; font-weight: 600; letter-spacing: 0.08em; text-transform: uppercase; color: #555; display: flex; justify-content: space-between; } | |
| .prediction { padding: 12px 16px; display: flex; flex-direction: column; gap: 6px; border-top: 1px solid #1f1f1f; } | |
| .prediction:first-of-type { border-top: none; } | |
| .pred-row { display: flex; align-items: center; justify-content: space-between; gap: 12px; } | |
| .rank { font-size: 0.72rem; font-weight: 600; color: #444; width: 20px; flex-shrink: 0; } | |
| .coords { font-size: 0.88rem; font-variant-numeric: tabular-nums; color: #ddd; flex: 1; } | |
| .prob-label { font-size: 0.78rem; color: #666; font-variant-numeric: tabular-nums; flex-shrink: 0; } | |
| .bar-wrap { height: 3px; background: #222; border-radius: 2px; overflow: hidden; margin-left: 32px; } | |
| .bar { height: 100%; background: #fff; border-radius: 2px; transition: width 0.4s ease; } | |
| .prediction:nth-child(1) .bar { background: #fff; } | |
| .prediction:nth-child(2) .bar { background: #aaa; } | |
| .prediction:nth-child(3) .bar { background: #777; } | |
| .prediction:nth-child(4) .bar { background: #555; } | |
| .prediction:nth-child(5) .bar { background: #3a3a3a; } | |
| .meta { padding: 10px 16px; font-size: 0.72rem; color: #444; border-top: 1px solid #1f1f1f; text-align: right; } | |
| #map-wrap { border-bottom: 1px solid #272727; } | |
| #map { height: 260px; background: #111; } | |
| .leaflet-control-attribution { background: rgba(0,0,0,0.5) !important; color: #555 !important; font-size: 0.6rem !important; } | |
| .leaflet-control-attribution a { color: #666 !important; } | |
| #status { font-size: 0.8rem; color: #666; margin-top: 20px; min-height: 20px; } | |
| </style> | |
| </head> | |
| <body> | |
| <header> | |
| <h1>GeoCLIP</h1> | |
| <p>Upload an image to predict its location</p> | |
| </header> | |
| <div class="card"> | |
| <div class="drop-zone" id="drop-zone"> | |
| <input type="file" id="file-input" accept="image/*"> | |
| <div class="drop-icon">🌍</div> | |
| <span><b>Click to upload</b> or drag and drop</span> | |
| <span>JPG, PNG, WEBP</span> | |
| </div> | |
| <div id="preview-wrap"><img id="preview" src="" alt="preview"></div> | |
| <div class="actions" id="actions"><button id="predict-btn">Predict location</button></div> | |
| <div id="results"> | |
| <div class="results-header"><span>Predictions</span><span>Top 5</span></div> | |
| <div id="map-wrap"><div id="map"></div></div> | |
| <div id="predictions-list"></div> | |
| <div class="meta" id="meta"></div> | |
| </div> | |
| </div> | |
| <div id="status"></div> | |
| <script> | |
// DOM handles used by the page script.
const dropZone = document.getElementById('drop-zone');
const fileInput = document.getElementById('file-input');
const previewWrap = document.getElementById('preview-wrap');
const preview = document.getElementById('preview');
const actions = document.getElementById('actions');
const predictBtn = document.getElementById('predict-btn');
const results = document.getElementById('results');
const predList = document.getElementById('predictions-list');
const meta = document.getElementById('meta');
const status = document.getElementById('status');
// Currently chosen image, plus the Leaflet objects reused between predictions.
let selectedFile = null;
let leafletMap = null;
let heatLayer = null;
let pinMarkers = [];
/**
 * Stores the chosen image, shows its preview and resets any earlier results.
 * Files that are missing or whose type does not start with image/ are ignored.
 */
function setFile(file) {
  if (!file || !file.type.startsWith('image/')) return;
  selectedFile = file;
  // Release the blob URL of the previous preview so repeated selections do not accumulate in memory.
  if (preview.src.startsWith('blob:')) URL.revokeObjectURL(preview.src);
  preview.src = URL.createObjectURL(file);
  previewWrap.style.display = 'block';
  actions.style.display = 'block';
  results.style.display = 'none';
  predList.innerHTML = '';
  status.textContent = '';
}
fileInput.addEventListener('change', e => setFile(e.target.files[0]));
dropZone.addEventListener('dragover', e => { e.preventDefault(); dropZone.classList.add('drag-over'); });
dropZone.addEventListener('dragleave', () => dropZone.classList.remove('drag-over'));
dropZone.addEventListener('drop', e => { e.preventDefault(); dropZone.classList.remove('drag-over'); setFile(e.dataTransfer.files[0]); });
| predictBtn.addEventListener('click', async () => { | |
| if (!selectedFile) return; | |
| predictBtn.disabled = true; | |
| predictBtn.textContent = 'Predicting…'; | |
| status.textContent = ''; | |
| results.style.display = 'none'; | |
| const form = new FormData(); | |
| form.append('file', selectedFile); | |
| try { | |
| const res = await fetch('/predict', { method: 'POST', body: form }); | |
| if (!res.ok) throw new Error('server error ' + res.status); | |
| const data = await res.json(); | |
| const maxProb = data.predictions[0].prob; | |
| predList.innerHTML = data.predictions.map(p => \` | |
| <div class="prediction"> | |
| <div class="pred-row"> | |
| <span class="rank">\${p.rank}</span> | |
| <span class="coords">\${p.lat.toFixed(4)}, \${p.lon.toFixed(4)}</span> | |
| <span class="prob-label">\${(p.prob * 100).toFixed(2)}%</span> | |
| </div> | |
| <div class="bar-wrap"><div class="bar" style="width:\${(p.prob / maxProb * 100).toFixed(1)}%"></div></div> | |
| </div> | |
| \`).join(''); | |
| meta.textContent = data.inference_ms + 'ms'; | |
| results.style.display = 'block'; | |
| const heatMax = data.heatmap[0].prob; | |
| const heatPoints = data.heatmap.map(p => [p.lat, p.lon, p.prob / heatMax]); | |
| setTimeout(() => { | |
| if (!leafletMap) { | |
| leafletMap = L.map('map', { zoomControl: true, attributionControl: true }); | |
| L.tileLayer('https://{s}.basemaps.cartocdn.com/dark_all/{z}/{x}/{y}{r}.png', { | |
| maxZoom: 19, | |
| attribution: '© <a href="https://carto.com/">CARTO</a>', | |
| }).addTo(leafletMap); | |
| } | |
| leafletMap.invalidateSize(); | |
| if (heatLayer) leafletMap.removeLayer(heatLayer); | |
| heatLayer = L.heatLayer(heatPoints, { radius: 28, blur: 20, maxZoom: 17, max: 1.0 }).addTo(leafletMap); | |
| pinMarkers.forEach(m => m.remove()); | |
| pinMarkers = data.predictions.map((p, i) => L.circleMarker([p.lat, p.lon], { | |
| radius: i === 0 ? 9 : 5, | |
| fillColor: i === 0 ? '#ffffff' : '#888888', | |
| color: i === 0 ? '#cccccc' : '#555555', | |
| weight: i === 0 ? 2 : 1, | |
| fillOpacity: i === 0 ? 1 : 0.65, | |
| }).bindTooltip(\`#\${p.rank} \${p.lat.toFixed(4)}, \${p.lon.toFixed(4)}\`, { sticky: false, opacity: 0.9 }).addTo(leafletMap)); | |
| leafletMap.fitBounds(L.latLngBounds(heatPoints.map(p => [p[0], p[1]])).pad(0.25)); | |
| }, 0); | |
| } catch (err) { | |
| status.textContent = err.message; | |
| } finally { | |
| predictBtn.disabled = false; | |
| predictBtn.textContent = 'Predict location'; | |
| } | |
| }); | |
| <\/script> | |
| </body> | |
| </html>`; | |
// Health probe: reports the number of loaded gallery coordinates (0 while gpsData is unset).
app.get('/health', (_req, res) => {
  const gallerySize = gpsData?.length ?? 0;
  res.json({ status: 'ok', gallery_size: gallerySize });
});

// Serves the single-page UI.
app.get('/', (_req, res) => {
  res.type('html').send(HTML);
});

// Finish loading models and the gallery before accepting connections.
await init();

app.listen(7860, () => {
  console.log('listening on :7860');
});