Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form, Request, WebSocket | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import StreamingResponse | |
| import sqlite3, os, uuid, time, json, hashlib, secrets, shutil, asyncio, threading, io, subprocess, sys | |
| from typing import Optional | |
| from datetime import datetime | |
| from pathlib import Path | |
| # ββ ensure SigLIP2 dependencies are installed ββββββββββββββββββββββββββββββββ | |
# ── ensure SigLIP2 dependencies are installed ────────────────────────────────
def _ensure_deps():
    """Install any missing runtime dependencies via a single quiet pip call."""
    # (import name, PyPI distribution name) — they differ for several of these.
    required = [
        ("sentencepiece", "sentencepiece"),
        ("google.protobuf", "protobuf"),
        ("clip", "openai-clip"),
        ("cv2", "opencv-python-headless"),
    ]
    missing = []
    for module_name, package_name in required:
        try:
            __import__(module_name)
        except ImportError:
            missing.append(package_name)
    if missing:
        # check=True surfaces installation failures instead of continuing broken.
        subprocess.run([sys.executable, "-m", "pip", "install", "-q"] + missing,
                       check=True)
_ensure_deps()
# ── CONFIG ────────────────────────────────────────────────────────────────────
# Shared secret used to salt password hashes; override via env in production.
SECRET_KEY = os.environ.get("SECRET_KEY", "findit-dev-secret")
DB_PATH = "/app/data/findit.db"   # SQLite database file (persistent volume)
IMG_DIR = "/app/data/images"      # uploaded images, served statically at /images
TOKEN_TTL = 60 * 60 * 24 * 30  # 30 days
Path(IMG_DIR).mkdir(parents=True, exist_ok=True)  # ensure image dir exists at startup
app = FastAPI(title="FindIt API")
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive CORS — fine for a demo deployment, tighten for production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])
app.mount("/images", StaticFiles(directory=IMG_DIR), name="images")
| # ββ PUB/SUB BROKER ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class Broker: | |
| def __init__(self): | |
| self.listeners: dict[str, list[asyncio.Queue]] = {} # channel β queues | |
| def subscribe(self, channel: str) -> asyncio.Queue: | |
| q = asyncio.Queue(maxsize=50) | |
| self.listeners.setdefault(channel, []).append(q) | |
| return q | |
| def unsubscribe(self, channel: str, q: asyncio.Queue): | |
| lst = self.listeners.get(channel, []) | |
| if q in lst: lst.remove(q) | |
| def publish(self, channel: str, data: dict): | |
| msg = f"data: {json.dumps(data)}\n\n" | |
| try: | |
| loop = asyncio.get_event_loop() | |
| except RuntimeError: | |
| loop = None | |
| def _put(queues): | |
| for q in list(queues): | |
| try: q.put_nowait(msg) | |
| except asyncio.QueueFull: pass | |
| def _send(): | |
| _put(self.listeners.get(channel, [])) | |
| # Also forward to "all" so post/comment SSE still works | |
| if channel != "all": | |
| _put(self.listeners.get("all", [])) | |
| if loop and loop.is_running(): | |
| # Called from a background thread β schedule on the event loop | |
| loop.call_soon_threadsafe(_send) | |
| else: | |
| _send() | |
| broker = Broker() | |
| # ββ DATABASE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── DATABASE ──────────────────────────────────────────────────────────────────
def get_db():
    """Open a SQLite connection with dict-like rows and WAL journaling enabled."""
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA journal_mode=WAL")
    return connection
def init_db():
    """Create all application tables if they do not yet exist (idempotent).

    Schema notes:
      - profiles.role: 'user' | 'admin' | 'super_admin'
      - posts are soft-deleted (is_deleted flag); comments are hard-deleted
      - reports UNIQUE(post_id, reporter_id): one report per user per post
      - comment_votes UNIQUE(comment_id, user_id): one vote per user per comment
    """
    db = get_db()
    db.executescript("""
    CREATE TABLE IF NOT EXISTS profiles (
      id TEXT PRIMARY KEY,
      uid TEXT UNIQUE NOT NULL,
      name TEXT NOT NULL,
      initials TEXT NOT NULL,
      color TEXT NOT NULL DEFAULT '#5b8dff',
      role TEXT NOT NULL DEFAULT 'user',
      is_banned INTEGER NOT NULL DEFAULT 0,
      points INTEGER NOT NULL DEFAULT 0,
      created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS passwords (
      user_id TEXT PRIMARY KEY REFERENCES profiles(id),
      hash TEXT NOT NULL
    );
    CREATE TABLE IF NOT EXISTS sessions (
      token TEXT PRIMARY KEY,
      user_id TEXT NOT NULL REFERENCES profiles(id),
      expires_at INTEGER NOT NULL
    );
    CREATE TABLE IF NOT EXISTS posts (
      id TEXT PRIMARY KEY,
      author_id TEXT NOT NULL REFERENCES profiles(id),
      title TEXT NOT NULL,
      description TEXT NOT NULL DEFAULT '',
      location TEXT NOT NULL,
      category TEXT NOT NULL,
      status TEXT NOT NULL DEFAULT 'found',
      image_url TEXT,
      is_deleted INTEGER NOT NULL DEFAULT 0,
      created_at TEXT NOT NULL DEFAULT (datetime('now')),
      updated_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS comments (
      id TEXT PRIMARY KEY,
      post_id TEXT NOT NULL REFERENCES posts(id),
      author_id TEXT NOT NULL REFERENCES profiles(id),
      parent_id TEXT REFERENCES comments(id),
      body TEXT NOT NULL DEFAULT '',
      image_url TEXT,
      created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS mod_log (
      id TEXT PRIMARY KEY,
      admin_id TEXT NOT NULL REFERENCES profiles(id),
      action TEXT NOT NULL,
      target_id TEXT,
      post_id TEXT,
      note TEXT,
      created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS admin_requests (
      id TEXT PRIMARY KEY,
      user_id TEXT NOT NULL REFERENCES profiles(id),
      email TEXT NOT NULL,
      name TEXT NOT NULL,
      role_title TEXT NOT NULL,
      reason TEXT NOT NULL,
      id_image_url TEXT,
      status TEXT NOT NULL DEFAULT 'pending',
      created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS alerts (
      id TEXT PRIMARY KEY,
      user_id TEXT NOT NULL REFERENCES profiles(id),
      admin_id TEXT NOT NULL REFERENCES profiles(id),
      note TEXT NOT NULL DEFAULT '',
      expires_at TEXT NOT NULL,
      created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS reports (
      id TEXT PRIMARY KEY,
      post_id TEXT NOT NULL REFERENCES posts(id),
      reporter_id TEXT NOT NULL REFERENCES profiles(id),
      reason TEXT NOT NULL DEFAULT '',
      created_at TEXT NOT NULL DEFAULT (datetime('now')),
      UNIQUE(post_id, reporter_id)
    );
    CREATE TABLE IF NOT EXISTS dms (
      id TEXT PRIMARY KEY,
      sender_id TEXT NOT NULL REFERENCES profiles(id),
      receiver_id TEXT NOT NULL REFERENCES profiles(id),
      body TEXT NOT NULL DEFAULT '',
      image_url TEXT,
      read INTEGER NOT NULL DEFAULT 0,
      created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS comment_votes (
      id TEXT PRIMARY KEY,
      comment_id TEXT NOT NULL REFERENCES comments(id) ON DELETE CASCADE,
      user_id TEXT NOT NULL REFERENCES profiles(id),
      vote INTEGER NOT NULL,
      created_at TEXT NOT NULL DEFAULT (datetime('now')),
      UNIQUE(comment_id, user_id)
    );
    """)
    db.commit()
    db.close()
init_db()
| # ββ HELPERS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def hash_password(pw: str) -> str:
    """Hash a password with the app-wide SECRET_KEY as a static salt.

    NOTE(review): single-round SHA-256 with a shared (not per-user) salt is
    weak for password storage — a real KDF (bcrypt/scrypt/argon2) would be
    preferable, but changing this would invalidate every stored hash.
    """
    return hashlib.sha256((pw + SECRET_KEY).encode()).hexdigest()
def make_token() -> str:
    """Return a 64-char hex session token from 32 cryptographically random bytes."""
    return secrets.token_bytes(32).hex()
COLORS = ["#5b8dff","#22c97a","#a084f5","#ff6b6b","#f5a623","#4da6ff","#e8385a","#00c9a7"]
def pick_color(uid: str) -> str:
    """Deterministically map a username to one of the preset avatar colors."""
    checksum = sum(map(ord, uid))
    return COLORS[checksum % len(COLORS)]
def get_current_user(request: Request):
    """Resolve the Bearer token in the Authorization header to a profile dict.

    Returns None when no valid, unexpired session matches.
    """
    header = request.headers.get("Authorization", "")
    if not header.startswith("Bearer "):
        return None
    token = header[len("Bearer "):]
    db = get_db()
    try:
        row = db.execute(
            "SELECT p.* FROM sessions s JOIN profiles p ON p.id=s.user_id "
            "WHERE s.token=? AND s.expires_at>?", (token, int(time.time()))
        ).fetchone()
    finally:
        db.close()
    return dict(row) if row else None
def require_user(request: Request):
    """Dependency: return the authenticated profile, or raise 401."""
    profile = get_current_user(request)
    if profile is None:
        raise HTTPException(401, "Not authenticated")
    return profile
def require_admin(request: Request):
    """Dependency: return the authenticated profile if admin/super_admin, else 403."""
    profile = require_user(request)
    if profile["role"] in ("admin", "super_admin"):
        return profile
    raise HTTPException(403, "Not admin")
def create_session(user_id: str) -> str:
    """Create a session row for `user_id` and return its bearer token."""
    token = make_token()
    db = get_db()
    try:
        db.execute(
            "INSERT INTO sessions (token,user_id,expires_at) VALUES (?,?,?)",
            (token, user_id, int(time.time()) + TOKEN_TTL),
        )
        db.commit()
    finally:
        db.close()
    return token
| # ββ AUTH ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── AUTH ──────────────────────────────────────────────────────────────────────
async def register(data: dict):
    """Create a new account and log it in; the first account becomes super_admin.

    Returns {"token", "profile"}. Raises 400 on invalid input, 409 if the
    username is taken (including the insert race — see below).
    """
    uid = data.get("uid", "").strip().lower()
    pw = data.get("password", "").strip()
    if not uid or not pw:
        raise HTTPException(400, "Username and password required")
    if len(uid) < 3:
        raise HTTPException(400, "Username must be at least 3 characters")
    if len(pw) < 6:
        raise HTTPException(400, "Password must be at least 6 characters")
    if not uid.replace("_", "").isalnum():
        raise HTTPException(400, "Username can only contain letters, numbers, underscores")
    db = get_db()
    try:
        if db.execute("SELECT 1 FROM profiles WHERE uid=?", (uid,)).fetchone():
            raise HTTPException(409, "Username already taken")
        pid = str(uuid.uuid4())
        name = data.get("name", uid).strip() or uid
        initials = name[:2].upper()
        color = pick_color(uid)
        # First user ever registered becomes super_admin automatically
        user_count = db.execute("SELECT COUNT(*) FROM profiles").fetchone()[0]
        role = 'super_admin' if user_count == 0 else 'user'
        try:
            db.execute("INSERT INTO profiles (id,uid,name,initials,color,role) VALUES (?,?,?,?,?,?)",
                       (pid, uid, name, initials, color, role))
        except sqlite3.IntegrityError:
            # Race: a concurrent request inserted the same uid between our
            # existence check and this INSERT — report 409, not a 500.
            raise HTTPException(409, "Username already taken")
        db.execute("INSERT INTO passwords (user_id,hash) VALUES (?,?)",
                   (pid, hash_password(pw)))
        db.commit()
        profile = dict(db.execute("SELECT * FROM profiles WHERE id=?", (pid,)).fetchone())
    finally:
        # Close on every path — the original leaked the connection on the 409 raise.
        db.close()
    token = create_session(pid)
    return {"token": token, "profile": profile}
async def login(data: dict):
    """Authenticate by uid/password and return {"token", "profile"}.

    Raises 400 on missing fields, 401 on bad credentials (same message for
    unknown user and wrong password), 403 if the account is banned.
    """
    uid = data.get("uid", "").strip().lower()
    pw = data.get("password", "").strip()
    if not uid or not pw:
        raise HTTPException(400, "Username and password required")
    db = get_db()
    profile = db.execute("SELECT * FROM profiles WHERE uid=?", (uid,)).fetchone()
    if not profile:
        db.close()
        raise HTTPException(401, "Wrong username or password")
    pw_row = db.execute("SELECT hash FROM passwords WHERE user_id=?",
                        (profile["id"],)).fetchone()
    db.close()
    # compare_digest: constant-time comparison, avoids leaking how many leading
    # hash characters matched via response timing (original used `!=`).
    if not pw_row or not secrets.compare_digest(pw_row["hash"], hash_password(pw)):
        raise HTTPException(401, "Wrong username or password")
    if profile["is_banned"]:
        raise HTTPException(403, "Account banned")
    token = create_session(profile["id"])
    return {"token": token, "profile": dict(profile)}
async def logout(request: Request):
    """Invalidate the caller's session token, if one was presented."""
    header = request.headers.get("Authorization", "")
    if header.startswith("Bearer "):
        token = header[7:]
        db = get_db()
        db.execute("DELETE FROM sessions WHERE token=?", (token,))
        db.commit()
        db.close()
    return {"ok": True}
async def get_me(user=Depends(require_user)):
    """Return the caller's own profile; banned accounts get 403."""
    if not user.get("is_banned"):
        return {"profile": user}
    raise HTTPException(403, "Account banned")
| # ββ PROFILES ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── PROFILES ──────────────────────────────────────────────────────────────────
async def profile_stats(uid: str):
    """Public stats for a profile: post/comment counts, derived points, role.

    Raises 404 if the uid is unknown.
    """
    db = get_db()
    try:
        p = db.execute("SELECT * FROM profiles WHERE uid=?", (uid,)).fetchone()
        if not p:
            raise HTTPException(404, "User not found")
        posts = db.execute("SELECT COUNT(*) FROM posts WHERE author_id=? AND is_deleted=0",
                           (p["id"],)).fetchone()[0]
        comments = db.execute("SELECT COUNT(*) FROM comments WHERE author_id=?",
                              (p["id"],)).fetchone()[0]
    finally:
        # Close on every path — the original leaked the connection on the 404 raise.
        db.close()
    # Points are derived, not read from profiles.points: 50 per post, 10 per comment.
    return {"postCount": posts, "commentCount": comments,
            "points": posts * 50 + comments * 10, "role": p["role"]}
async def profile_posts(uid: str):
    """All non-deleted posts authored by `uid`, newest first, with author info."""
    query = """
        SELECT p.*, pr.uid as author_uid, pr.name as author_name,
               pr.initials as author_initials, pr.color as author_color,
               (SELECT COUNT(*) FROM comments c WHERE c.post_id=p.id) as comment_count
        FROM posts p JOIN profiles pr ON pr.id=p.author_id
        WHERE pr.uid=? AND p.is_deleted=0 ORDER BY p.created_at DESC
    """
    db = get_db()
    posts = [dict(row) for row in db.execute(query, (uid,)).fetchall()]
    db.close()
    return posts
| # ββ POSTS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── POSTS ─────────────────────────────────────────────────────────────────────
async def get_posts():
    """Every non-deleted post with author details and comment counts, newest first."""
    query = """
        SELECT p.*, pr.uid as author_uid, pr.name as author_name,
               pr.initials as author_initials, pr.color as author_color,
               pr.role as author_role, pr.is_banned as author_banned,
               (SELECT COUNT(*) FROM comments c WHERE c.post_id=p.id) as comment_count
        FROM posts p JOIN profiles pr ON pr.id=p.author_id
        WHERE p.is_deleted=0 ORDER BY p.created_at DESC
    """
    db = get_db()
    posts = [dict(row) for row in db.execute(query).fetchall()]
    db.close()
    return posts
async def get_posts_since(ts: str = ""):
    """Non-deleted posts created strictly after timestamp string `ts`, newest first."""
    query = ("SELECT p.*, pr.uid as author_uid, pr.name as author_name,"
             "pr.initials as author_initials, pr.color as author_color,"
             "pr.role as author_role, pr.is_banned as author_banned,"
             "(SELECT COUNT(*) FROM comments c WHERE c.post_id=p.id) as comment_count "
             "FROM posts p JOIN profiles pr ON pr.id=p.author_id "
             "WHERE p.is_deleted=0 AND p.created_at > ? ORDER BY p.created_at DESC")
    db = get_db()
    results = [dict(row) for row in db.execute(query, (ts,)).fetchall()]
    db.close()
    return results
async def stream(request: Request, channel: str = "all"):
    """SSE endpoint. channel='all' for posts, 'dm:{uid}' for DM notifications."""
    # Subscribe before the response starts streaming so no events are missed.
    q = broker.subscribe(channel)
    async def event_generator():
        try:
            # Send a heartbeat immediately to confirm connection
            yield "data: {\"type\":\"connected\"}\n\n"
            while True:
                if await request.is_disconnected():
                    break
                try:
                    # Wake at least every 25s to emit a keepalive so idle
                    # connections aren't dropped by proxies.
                    msg = await asyncio.wait_for(q.get(), timeout=25.0)
                    yield msg
                except asyncio.TimeoutError:
                    yield ": keepalive\n\n"  # SSE comment = keepalive
        finally:
            # Always detach the queue — also runs on cancellation/disconnect.
            broker.unsubscribe(channel, q)
    # X-Accel-Buffering: no → tells nginx not to buffer the event stream.
    return StreamingResponse(event_generator(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
async def create_post(request: Request, user=Depends(require_user)):
    """Create a post owned by the caller and broadcast it on the 'all' channel.

    Requires title, location and category; raises 400 when any is missing or
    empty (the original indexed data["title"] directly, turning a missing
    field into a KeyError → HTTP 500).
    """
    data = await request.json()
    missing = [k for k in ("title", "location", "category") if not data.get(k)]
    if missing:
        raise HTTPException(400, f"Missing required fields: {', '.join(missing)}")
    pid = str(uuid.uuid4())
    db = get_db()
    db.execute(
        "INSERT INTO posts (id,author_id,title,description,location,category,status,image_url) VALUES (?,?,?,?,?,?,?,?)",
        (pid, user["id"], data["title"], data.get("description", ""),
         data["location"], data["category"], data.get("status", "found"), data.get("image_url"))
    )
    db.commit()
    # Re-read the row joined with the author so the broadcast payload matches get_posts().
    row = db.execute("""
        SELECT p.*, pr.uid as author_uid, pr.name as author_name,
               pr.initials as author_initials, pr.color as author_color, 0 as comment_count
        FROM posts p JOIN profiles pr ON pr.id=p.author_id WHERE p.id=?
    """, (pid,)).fetchone()
    db.close()
    post_data = dict(row)
    broker.publish("all", {"type": "new_post", "post": post_data})
    return post_data
async def update_post(post_id: str, request: Request, user=Depends(require_user)):
    """Patch whitelisted post fields; author or admin only.

    Raises 404 for an unknown post, 403 for non-author/non-admin callers.
    """
    data = await request.json()
    db = get_db()
    try:
        post = db.execute("SELECT * FROM posts WHERE id=?", (post_id,)).fetchone()
        if not post:
            raise HTTPException(404)
        if post["author_id"] != user["id"] and user["role"] not in ("admin", "super_admin"):
            raise HTTPException(403)
        # Only these fields are client-patchable; updated_at is always refreshed.
        fields = {k: v for k, v in data.items()
                  if k in ("title", "description", "status", "image_url", "is_deleted")}
        fields["updated_at"] = datetime.utcnow().isoformat()
        sets = ", ".join(f"{k}=?" for k in fields)  # keys are whitelisted, safe to format
        db.execute(f"UPDATE posts SET {sets} WHERE id=?", (*fields.values(), post_id))
        db.commit()
    finally:
        # Close on every path — the original leaked the connection on 404/403 raises.
        db.close()
    return {"ok": True}
async def delete_post(post_id: str, user=Depends(require_user)):
    """Soft-delete a post (sets is_deleted=1).

    Authors may delete their own posts; admins may delete users' posts; only
    super admins may delete posts authored by admins/super admins.
    """
    db = get_db()
    try:
        post = db.execute("SELECT p.*, pr.role as author_role FROM posts p JOIN profiles pr ON pr.id=p.author_id WHERE p.id=?", (post_id,)).fetchone()
        if not post:
            raise HTTPException(404)
        is_own = post["author_id"] == user["id"]
        is_admin = user["role"] in ("admin", "super_admin")
        is_super = user["role"] == "super_admin"
        author_is_admin = post["author_role"] in ("admin", "super_admin")
        # owner can always delete own post
        if not is_own and not is_admin:
            raise HTTPException(403)
        # admin can't delete other admin or super posts — only super can
        if not is_own and author_is_admin and not is_super:
            raise HTTPException(403, "Only super admin can delete admin/super posts")
        db.execute("UPDATE posts SET is_deleted=1 WHERE id=?", (post_id,))
        db.commit()
    finally:
        # Close on every path — the original leaked the connection on 404/403 raises.
        db.close()
    return {"ok": True}
| # ββ COMMENTS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── COMMENTS ──────────────────────────────────────────────────────────────────
async def get_comments(post_id: str, request: Request):
    """All comments on a post (flat, oldest first) with vote tallies.

    Auth is optional: when a valid Bearer token is present, each comment also
    carries the caller's own vote as my_vote (-1/0/1); otherwise my_vote is 0.
    """
    db = get_db()
    # Try to identify current user for my_vote
    user_id = None
    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        tok = auth[7:]
        # NOTE: no expiry check here (unlike get_current_user) — an expired
        # token still reveals my_vote, which only personalizes, never authorizes.
        row = db.execute("SELECT user_id FROM sessions WHERE token=?", (tok,)).fetchone()
        if row: user_id = row["user_id"]
    # LEFT JOIN + GROUP BY: comments with no votes still appear (sums default 0).
    rows = db.execute("""
        SELECT c.id, c.post_id, c.author_id, c.parent_id, c.body, c.image_url, c.created_at,
               pr.uid as author_uid, pr.name as author_name,
               pr.initials as author_initials, pr.color as author_color,
               COALESCE(SUM(CASE WHEN v.vote=1 THEN 1 ELSE 0 END),0) as upvotes,
               COALESCE(SUM(CASE WHEN v.vote=-1 THEN 1 ELSE 0 END),0) as downvotes,
               COALESCE(SUM(v.vote),0) as net_votes
        FROM comments c
        JOIN profiles pr ON pr.id=c.author_id
        LEFT JOIN comment_votes v ON v.comment_id=c.id
        WHERE c.post_id=?
        GROUP BY c.id
        ORDER BY c.created_at ASC
    """, (post_id,)).fetchall()
    # Fetch the caller's votes once (all their votes, not just this post's)
    # and merge in Python rather than re-joining per comment.
    my_votes = {}
    if user_id:
        mv = db.execute(
            "SELECT comment_id, vote FROM comment_votes WHERE user_id=?", (user_id,)
        ).fetchall()
        my_votes = {r["comment_id"]: r["vote"] for r in mv}
    db.close()
    # Reshape flat author_* columns into a nested "author" object for the client.
    result = []
    for r in rows:
        d = dict(r)
        d["author"] = {"uid": d.pop("author_uid"), "name": d.pop("author_name"),
                       "initials": d.pop("author_initials"), "color": d.pop("author_color")}
        d["my_vote"] = my_votes.get(d["id"], 0)
        result.append(d)
    return result
async def create_comment(post_id: str, request: Request, user=Depends(require_user)):
    """Add a comment (optionally a reply and/or image) and broadcast it to post viewers."""
    payload = await request.json()
    comment_id = str(uuid.uuid4())
    db = get_db()
    db.execute(
        "INSERT INTO comments (id,post_id,author_id,parent_id,body,image_url) VALUES (?,?,?,?,?,?)",
        (comment_id, post_id, user["id"], payload.get("parent_id"),
         payload.get("body", ""), payload.get("image_url")),
    )
    db.commit()
    row = db.execute("""
        SELECT c.*, pr.uid as author_uid, pr.name as author_name,
               pr.initials as author_initials, pr.color as author_color
        FROM comments c JOIN profiles pr ON pr.id=c.author_id WHERE c.id=?
    """, (comment_id,)).fetchone()
    db.close()
    comment = dict(row)
    comment["author"] = {
        "uid": comment.pop("author_uid"),
        "name": comment.pop("author_name"),
        "initials": comment.pop("author_initials"),
        "color": comment.pop("author_color"),
    }
    # A brand-new comment starts with zeroed vote counters.
    for key in ("net_votes", "my_vote", "upvotes", "downvotes"):
        comment[key] = 0
    # Broadcast to everyone viewing this post
    broker.publish(f"post:{post_id}", {"type": "new_comment", "comment": comment})
    return comment
| # ββ IMAGE UPLOAD ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── IMAGE UPLOAD ──────────────────────────────────────────────────────────────
async def edit_comment(comment_id: str, request: Request, user=Depends(require_user)):
    """Edit a comment's body; author only.

    Raises 404 for an unknown comment, 403 for non-authors, 400 for empty body.
    Fix: the original opened one connection via a bare get_db().execute() that
    was never closed, then opened a second connection for the UPDATE — now a
    single connection is used and always closed.
    """
    db = get_db()
    try:
        comment = db.execute("SELECT * FROM comments WHERE id=?", (comment_id,)).fetchone()
        if not comment:
            raise HTTPException(404)
        if comment["author_id"] != user["id"]:
            raise HTTPException(403, "Only the author can edit")
        data = await request.json()
        body = data.get("body", "").strip()
        if not body:
            raise HTTPException(400, "Body required")
        db.execute("UPDATE comments SET body=? WHERE id=?", (body, comment_id))
        db.commit()
    finally:
        db.close()
    return {"ok": True, "body": body}
async def report_comment(comment_id: str, request: Request, user=Depends(require_user)):
    """Report a comment by filing a post-level report with a 'comment:' prefixed reason.

    Authors cannot report their own comments; admins cannot report at all.
    INSERT OR IGNORE + UNIQUE(post_id, reporter_id) makes duplicates no-ops.
    """
    db = get_db()
    try:
        comment = db.execute("SELECT * FROM comments WHERE id=?", (comment_id,)).fetchone()
        if not comment:
            raise HTTPException(404)
        if comment["author_id"] == user["id"]:
            raise HTTPException(403, "Cannot report own comment")
        if user["role"] in ("admin", "super_admin"):
            raise HTTPException(403, "Admins cannot report")
        data = await request.json()
        db.execute(
            "INSERT OR IGNORE INTO reports (id,post_id,reporter_id,reason) VALUES (?,?,?,?)",
            (str(uuid.uuid4()), comment["post_id"], user["id"], f'comment:{comment_id}:{data.get("reason","")}')
        )
        db.commit()
    finally:
        # Close on every path — the original leaked the connection on 404/403 raises.
        db.close()
    return {"ok": True}
async def delete_comment(comment_id: str, user=Depends(require_user)):
    """Hard-delete a comment; allowed for its author or any admin."""
    db = get_db()
    comment = db.execute(
        "SELECT c.*, pr.role as author_role FROM comments c "
        "JOIN profiles pr ON pr.id=c.author_id WHERE c.id=?",
        (comment_id,)
    ).fetchone()
    if not comment:
        raise HTTPException(404)
    may_delete = (comment["author_id"] == user["id"]
                  or user["role"] in ("admin", "super_admin"))
    if not may_delete:
        raise HTTPException(403, "Cannot delete this comment")
    db.execute("DELETE FROM comments WHERE id=?", (comment_id,))
    db.commit()
    db.close()
    return {"ok": True}
async def vote_comment(comment_id: str, request: Request, user=Depends(require_user)):
    """Set or clear the caller's vote on a comment and return updated tallies.

    vote must be -1, 0 (clear) or 1; anything else gets 400. (The original
    stored arbitrary integers — e.g. vote=100 — corrupting net_votes, and a
    non-numeric vote raised ValueError → HTTP 500.)
    """
    data = await request.json()
    try:
        vote = int(data.get("vote", 0))
    except (TypeError, ValueError):
        raise HTTPException(400, "vote must be -1, 0 or 1")
    if vote not in (-1, 0, 1):
        raise HTTPException(400, "vote must be -1, 0 or 1")
    db = get_db()
    try:
        if vote == 0:
            db.execute("DELETE FROM comment_votes WHERE comment_id=? AND user_id=?",
                       (comment_id, user["id"]))
        else:
            # Upsert on UNIQUE(comment_id, user_id): re-voting replaces the old vote.
            sql = ("INSERT INTO comment_votes (id,comment_id,user_id,vote) VALUES (?,?,?,?) "
                   "ON CONFLICT(comment_id,user_id) DO UPDATE SET vote=excluded.vote")
            db.execute(sql, (str(uuid.uuid4()), comment_id, user["id"], vote))
        row = db.execute(
            "SELECT COALESCE(SUM(CASE WHEN vote=1 THEN 1 ELSE 0 END),0) as upvotes,"
            "COALESCE(SUM(CASE WHEN vote=-1 THEN 1 ELSE 0 END),0) as downvotes,"
            "COALESCE(SUM(vote),0) as net_votes FROM comment_votes WHERE comment_id=?",
            (comment_id,)
        ).fetchone()
        db.commit()
    finally:
        db.close()
    return {"ok": True, "upvotes": row["upvotes"], "downvotes": row["downvotes"],
            "net_votes": row["net_votes"], "my_vote": vote}
async def upload_image(file: UploadFile = File(...), user=Depends(require_user)):
    """Store an uploaded image under a random filename and return its public URL.

    Raises 400 for non-image content types. The extension is derived from the
    declared content type (image/jpeg → .jpg).
    """
    if not file.content_type.startswith("image/"):
        raise HTTPException(400, "Not an image")
    ext = file.content_type.split("/")[1].replace("jpeg", "jpg")
    filename = f"{uuid.uuid4()}.{ext}"
    with open(os.path.join(IMG_DIR, filename), "wb") as f:
        shutil.copyfileobj(file.file, f)
    # Bug fix: the original returned a literal "(unknown)" placeholder instead of
    # the saved filename, so every upload produced a dead URL and an orphaned file.
    return {"url": f"/images/{filename}"}
| # ββ DIRECT MESSAGES ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── DIRECT MESSAGES ──────────────────────────────────────────────────────────
async def get_conversations(user=Depends(require_user)):
    """Get all conversations for the current user, with latest message and unread count."""
    db = get_db()
    # The caller's id fills all 8 '?' placeholders, in order: two per
    # directional pair in the last_msg and last_at subqueries, one in the
    # unread count, one in the self-exclusion, and two in the EXISTS filter.
    rows = db.execute("""
        SELECT
          p.id, p.uid, p.name, p.initials, p.color,
          (SELECT body FROM dms WHERE
             (sender_id=? AND receiver_id=p.id) OR
             (sender_id=p.id AND receiver_id=?)
           ORDER BY created_at DESC LIMIT 1) as last_msg,
          (SELECT created_at FROM dms WHERE
             (sender_id=? AND receiver_id=p.id) OR
             (sender_id=p.id AND receiver_id=?)
           ORDER BY created_at DESC LIMIT 1) as last_at,
          (SELECT COUNT(*) FROM dms WHERE sender_id=p.id AND receiver_id=? AND read=0) as unread
        FROM profiles p
        WHERE p.id != ?
          AND EXISTS (
            SELECT 1 FROM dms WHERE
              (sender_id=? AND receiver_id=p.id) OR
              (sender_id=p.id AND receiver_id=?)
          )
        ORDER BY last_at DESC
    """, (user["id"],)*8).fetchall()
    db.close()
    return [dict(r) for r in rows]
async def get_dm_thread(other_uid: str, user=Depends(require_user)):
    """Full message history between the caller and `other_uid`.

    Side effect: every message received from the other user is marked read.
    """
    db = get_db()
    row = db.execute("SELECT * FROM profiles WHERE uid=?", (other_uid,)).fetchone()
    if not row:
        raise HTTPException(404, "User not found")
    other = dict(row)
    me, them = user["id"], other["id"]
    msgs = db.execute("""
        SELECT d.*, p.uid as sender_uid, p.initials as sender_initials, p.color as sender_color
        FROM dms d JOIN profiles p ON p.id=d.sender_id
        WHERE (d.sender_id=? AND d.receiver_id=?) OR (d.sender_id=? AND d.receiver_id=?)
        ORDER BY d.created_at ASC
    """, (me, them, them, me)).fetchall()
    # Mark received messages as read
    db.execute("UPDATE dms SET read=1 WHERE sender_id=? AND receiver_id=? AND read=0",
               (them, me))
    db.commit()
    db.close()
    return {"other": other, "messages": [dict(m) for m in msgs]}
async def send_dm(other_uid: str, request: Request, user=Depends(require_user)):
    """Persist a direct message and push an SSE notification to the recipient."""
    db = get_db()
    recipient = db.execute("SELECT * FROM profiles WHERE uid=?", (other_uid,)).fetchone()
    if not recipient:
        raise HTTPException(404, "User not found")
    payload = await request.json()
    body = payload.get("body", "").strip()
    image_url = payload.get("image_url")
    if not (body or image_url):
        raise HTTPException(400, "Empty message")
    mid = str(uuid.uuid4())
    db.execute("INSERT INTO dms (id,sender_id,receiver_id,body,image_url) VALUES (?,?,?,?,?)",
               (mid, user["id"], recipient["id"], body, image_url))
    db.commit()
    db.close()
    msg_data = {"id": mid, "sender_uid": user["uid"], "body": body, "image_url": image_url,
                "created_at": datetime.utcnow().isoformat(), "read": 0}
    # Notify recipient instantly via SSE
    broker.publish(f"dm:{other_uid}", {"type": "new_dm", "from_uid": user["uid"], "msg": msg_data})
    return msg_data
async def unread_count(user=Depends(require_user)):
    """Total number of unread DMs addressed to the caller."""
    db = get_db()
    (count,) = db.execute("SELECT COUNT(*) FROM dms WHERE receiver_id=? AND read=0",
                          (user["id"],)).fetchone()
    db.close()
    return {"count": count}
| # ββ ALERTS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── ALERTS ───────────────────────────────────────────────────────────────────
async def send_alert(target_uid: str, request: Request, user=Depends(require_admin)):
    """Issue a 30-day warning alert against a user; 5+ active alerts auto-bans.

    Admins cannot alert super admins; only super admins can alert admins.
    Fixes vs original: the target lookup opened a throwaway connection that was
    never closed; a dead placeholder `expiry` assignment was computed and
    immediately overwritten; the timedelta import sat mid-function body.
    """
    from datetime import timedelta  # file-level import only brings in datetime
    db = get_db()
    try:
        target = db.execute("SELECT * FROM profiles WHERE uid=?", (target_uid,)).fetchone()
        if not target:
            raise HTTPException(404)
        if target["role"] == "super_admin":
            raise HTTPException(403, "Cannot alert super admin")
        if target["role"] == "admin" and user["role"] != "super_admin":
            raise HTTPException(403, "Only super admin can alert admins")
        data = await request.json()
        # Count currently-active (unexpired) alerts before adding this one.
        now = datetime.utcnow().isoformat()
        count = db.execute("SELECT COUNT(*) FROM alerts WHERE user_id=? AND expires_at>?",
                           (target["id"], now)).fetchone()[0]
        expiry = (datetime.utcnow() + timedelta(days=30)).isoformat()
        db.execute("INSERT INTO alerts (id,user_id,admin_id,note,expires_at) VALUES (?,?,?,?,?)",
                   (str(uuid.uuid4()), target["id"], user["id"], data.get("note", ""), expiry))
        # Auto-ban after 5 active alerts
        new_count = count + 1
        if new_count >= 5:
            db.execute("UPDATE profiles SET is_banned=1 WHERE id=?", (target["id"],))
        db.commit()
    finally:
        db.close()
    return {"ok": True, "auto_banned": new_count >= 5}
async def get_alerts(uid: str, user=Depends(get_current_user)):
    """Active (unexpired) alerts against a user, newest first, with issuing admin uid."""
    db = get_db()
    target = db.execute("SELECT * FROM profiles WHERE uid=?", (uid,)).fetchone()
    if not target:
        raise HTTPException(404)
    cutoff = datetime.utcnow().isoformat()
    rows = db.execute("""
        SELECT a.*, p.uid as admin_uid FROM alerts a
        JOIN profiles p ON p.id=a.admin_id
        WHERE a.user_id=? AND a.expires_at>?
        ORDER BY a.created_at DESC
    """, (target["id"], cutoff)).fetchall()
    db.close()
    alerts = [dict(r) for r in rows]
    return {"count": len(alerts), "alerts": alerts}
| # ββ REPORTS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── REPORTS ──────────────────────────────────────────────────────────────────
async def report_post(post_id: str, request: Request, user=Depends(require_user)):
    """File a report against a post; a duplicate report is silently ignored.

    Admins and the post's own author may not report.
    Fix: the original's bare `except: pass` swallowed *every* error (disk full,
    schema mismatch, programming bugs), not just the intended UNIQUE violation —
    narrowed to sqlite3.IntegrityError. Connection is now closed on raise paths.
    """
    if user["role"] in ("admin", "super_admin"):
        raise HTTPException(403, "Admins cannot report posts")
    db = get_db()
    try:
        post = db.execute("SELECT * FROM posts WHERE id=?", (post_id,)).fetchone()
        if not post:
            raise HTTPException(404)
        if post["author_id"] == user["id"]:
            raise HTTPException(403, "Cannot report your own post")
        data = await request.json()
        try:
            db.execute("INSERT INTO reports (id,post_id,reporter_id,reason) VALUES (?,?,?,?)",
                       (str(uuid.uuid4()), post_id, user["id"], data.get("reason", "")))
            db.commit()
        except sqlite3.IntegrityError:
            # UNIQUE(post_id, reporter_id): this user already reported this post.
            pass
    finally:
        db.close()
    return {"ok": True}
async def delete_report(post_id: str, user=Depends(require_admin)):
    """Dismiss every report filed against the given post (admin only)."""
    db = get_db()
    try:
        db.execute("DELETE FROM reports WHERE post_id=?", (post_id,))
        db.commit()
    finally:
        db.close()
    return {"ok": True}
async def get_reports(user=Depends(require_admin)):
    """Admin view: reported, still-visible posts grouped with report counts.

    Each row is one post, aggregated over its reports: report_count and a
    comma-joined list of reasons, most-reported first.
    """
    db = get_db()
    rows = db.execute("""
        SELECT p.id, p.title, p.description, p.location, p.category, p.status,
               p.author_id, p.image_url, p.created_at,
               pr.uid as author_uid, pr.name as author_name,
               pr.initials as author_initials, pr.color as author_color,
               COUNT(r.id) as report_count,
               GROUP_CONCAT(r.reason, ', ') as reasons
        FROM reports r
        JOIN posts p ON p.id=r.post_id
        JOIN profiles pr ON pr.id=p.author_id
        WHERE p.is_deleted=0
        GROUP BY p.id ORDER BY report_count DESC
    """).fetchall()
    db.close()
    return [dict(r) for r in rows]
| # ββ MOD LOG ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── MOD LOG ──────────────────────────────────────────────────────────────────
async def create_log(request: Request, user=Depends(require_admin)):
    """Append an entry to the moderation log and notify admin SSE listeners."""
    entry = await request.json()
    db = get_db()
    db.execute(
        "INSERT INTO mod_log (id,admin_id,action,target_id,post_id,note) VALUES (?,?,?,?,?,?)",
        (str(uuid.uuid4()), user["id"], entry.get("action", ""),
         entry.get("target_id"), entry.get("post_id"), entry.get("note", "")),
    )
    db.commit()
    db.close()
    broker.publish("admin", {"type": "new_log"})
    return {"ok": True}
async def get_log(user=Depends(require_admin)):
    """Latest 200 moderation-log entries with admin/target/post names resolved.

    target and post joins are LEFT JOINs: entries survive deleted targets.
    """
    query = """
        SELECT l.*,
               a.uid as admin_uid, a.name as admin_name,
               t.uid as target_uid, t.name as target_name,
               p.title as post_title
        FROM mod_log l
        JOIN profiles a ON a.id=l.admin_id
        LEFT JOIN profiles t ON t.id=l.target_id
        LEFT JOIN posts p ON p.id=l.post_id
        ORDER BY l.created_at DESC LIMIT 200
    """
    db = get_db()
    entries = [dict(row) for row in db.execute(query).fetchall()]
    db.close()
    return entries
| # ββ ADMIN βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ── ADMIN ─────────────────────────────────────────────────────────────────────
async def admin_stats(user=Depends(require_admin)):
    """Dashboard counters for the admin panel (one scalar query per stat)."""
    queries = {
        "totalPosts": "SELECT COUNT(*) FROM posts WHERE is_deleted=0",
        "activePosts": "SELECT COUNT(*) FROM posts WHERE is_deleted=0 AND status!='recovered'",
        "totalUsers": "SELECT COUNT(*) FROM profiles",
        "admins": "SELECT COUNT(*) FROM profiles WHERE role IN ('admin','super_admin')",
        "bannedUsers": "SELECT COUNT(*) FROM profiles WHERE is_banned=1",
        "pendingRequests": "SELECT COUNT(*) FROM admin_requests WHERE status='pending'",
        "reportedPosts": "SELECT COUNT(DISTINCT post_id) FROM reports",
    }
    db = get_db()
    stats = {key: db.execute(sql).fetchone()[0] for key, sql in queries.items()}
    db.close()
    return stats
async def get_admin_requests(user=Depends(require_admin)):
    """List pending admin-access requests, newest first, joined with requester uid."""
    db = get_db()
    pending = db.execute(
        "SELECT r.*, p.uid as requester_uid "
        "FROM admin_requests r "
        "LEFT JOIN profiles p ON p.id=r.user_id "
        "WHERE r.status='pending' "
        "ORDER BY r.created_at DESC"
    ).fetchall()
    db.close()
    return [dict(req) for req in pending]
async def review_request(req_id: str, request: Request, user=Depends(require_admin)):
    """Set a request's status (default 'rejected') and broadcast the verdict."""
    body = await request.json()
    verdict = body.get("status", "rejected")
    db = get_db()
    db.execute("UPDATE admin_requests SET status=? WHERE id=?", (verdict, req_id))
    db.commit()
    db.close()
    broker.publish("admin", {"type": "request_reviewed", "req_id": req_id, "status": verdict})
    return {"ok": True}
async def get_all_users(user=Depends(require_admin)):
    """List every profile, oldest first, with computed activity stats.

    points = posts*50 + comments*10; vote_score = sum of votes on the
    user's comments (0 when they have none).
    """
    db = get_db()
    profiles = db.execute("SELECT * FROM profiles ORDER BY created_at ASC").fetchall()
    enriched = []
    for row in profiles:
        entry = dict(row)
        pid = entry["id"]
        n_posts = db.execute(
            "SELECT COUNT(*) FROM posts WHERE author_id=? AND is_deleted=0", (pid,)
        ).fetchone()[0]
        n_comments = db.execute(
            "SELECT COUNT(*) FROM comments WHERE author_id=?", (pid,)
        ).fetchone()[0]
        votes = db.execute(
            "SELECT COALESCE(SUM(cv.vote),0) as vs FROM comment_votes cv "
            "JOIN comments c ON c.id=cv.comment_id WHERE c.author_id=?", (pid,)
        ).fetchone()
        entry["points"] = n_posts * 50 + n_comments * 10
        entry["vote_score"] = int(votes["vs"]) if votes else 0
        enriched.append(entry)
    db.close()
    return enriched
async def admin_update_profile(user_id: str, request: Request, user=Depends(require_admin)):
    """Admin edit of a profile; only role / is_banned / points are writable."""
    payload = await request.json()
    db = get_db()
    # Whitelist of editable columns — everything else in the payload is ignored.
    allowed = {key: payload[key] for key in ("role", "is_banned", "points") if key in payload}
    if allowed:
        assignments = ", ".join(f"{col}=?" for col in allowed)
        db.execute(f"UPDATE profiles SET {assignments} WHERE id=?", (*allowed.values(), user_id))
    # Unbanning wipes the user's alerts so their strike count restarts at 0.
    if allowed.get("is_banned") == 0:
        db.execute("DELETE FROM alerts WHERE user_id=?", (user_id,))
    db.commit()
    db.close()
    return {"ok": True}
async def verify_id_and_grant(
    file: UploadFile = File(...),
    uid: str = Form(""),
    user=Depends(require_user),
):
    """
    One-shot ID verification: receive the image, run the Florence OCR-based
    ID check, and grant the admin role immediately when it passes.
    The image is stored to disk either way (audit trail / pending request).

    Returns {"auto_approved": bool, "confidence": float, ...}; errors are
    reported in-band rather than raised.
    """
    img_bytes = await file.read()
    db = get_db()
    try:
        id_result = _siglip_check_id(img_bytes)
        print(f"[verify-id] uid={uid} is_id={id_result.get('is_id')} conf={id_result.get('confidence',0):.2f} ocr={repr(id_result.get('ocr','')[:80])}")
        # Persist the uploaded image once, up front — both branches need it.
        # (This save block used to be duplicated verbatim in each branch.)
        ext = file.filename.rsplit(".", 1)[-1] if "." in (file.filename or "") else "jpg"
        fname = f"id_{uuid.uuid4().hex[:8]}.{ext}"
        img_path = os.path.join(IMG_DIR, fname)
        with open(img_path, "wb") as f:
            f.write(img_bytes)
        id_image_url = f"/images/{fname}"
        if not id_result.get("is_id"):
            # Not recognized as an ID card: leave the role untouched and hand
            # back the stored image URL so a manual request can reference it.
            return {"auto_approved": False, "confidence": id_result.get("confidence", 0), "id_image_url": id_image_url}
        # Grant admin role (lookup by submitted uid, falling back to the caller's).
        lookup_uid = uid.strip().lower() or user.get("uid", "")
        profile = db.execute("SELECT id FROM profiles WHERE uid=?", (lookup_uid,)).fetchone()
        if profile:
            db.execute("UPDATE profiles SET role='admin' WHERE id=?", (profile["id"],))
        # Record an already-approved request row for the audit log.
        db.execute(
            "INSERT INTO admin_requests (id,user_id,email,name,role_title,reason,id_image_url,status) VALUES (?,?,?,?,?,?,?,?)",
            (str(uuid.uuid4()), profile["id"] if profile else lookup_uid,
             lookup_uid, lookup_uid, "staff", "Auto-approved via ID scan", id_image_url, "approved")
        )
        db.commit()
        print(f"[verify-id] β admin granted to uid={lookup_uid}")
        return {"auto_approved": True, "confidence": id_result.get("confidence", 0)}
    except Exception as e:
        print(f"[verify-id] error: {e}")
        return {"auto_approved": False, "error": str(e)}
    finally:
        db.close()
async def submit_admin_request(request: Request):
    """
    Submit an admin-access request (note: no auth dependency on this route).

    If the payload carries an id_image_url AND the uid matches an existing
    profile, the image is run through _siglip_check_id; a confident result
    (confidence >= 0.45) grants the admin role immediately and the request is
    stored as 'approved'. Otherwise the request is stored as pending for
    manual review. Either way a 'new_request' event is published to admins.
    """
    data = await request.json()
    db = get_db()
    uid = data.get("uid","").strip().lower()
    profile = db.execute("SELECT id FROM profiles WHERE uid=?", (uid,)).fetchone() if uid else None
    # Unknown uid still gets a request row, under a fresh synthetic user_id.
    user_id = profile["id"] if profile else str(uuid.uuid4())
    # -- Auto ID check --------------------------------------------------------
    # If they uploaded an ID image, run the OCR/keyword check.
    # If it looks like a student/staff ID -> auto-approve immediately.
    id_image_url = data.get("id_image_url")
    auto_approved = False
    if id_image_url and profile:
        try:
            img_bytes = None
            img_path = os.path.join(IMG_DIR, os.path.basename(id_image_url))
            if os.path.exists(img_path):
                img_bytes = open(img_path, "rb").read()
            else:
                # File not on disk (Space restarted / stored remotely) -> fetch from URL
                import urllib.request
                try:
                    with urllib.request.urlopen(id_image_url, timeout=10) as resp:
                        img_bytes = resp.read()
                        print(f"[admin-auto] fetched ID from URL ({len(img_bytes)} bytes)")
                except Exception as fetch_err:
                    print(f"[admin-auto] could not fetch URL: {fetch_err}")
            if img_bytes:
                id_result = _siglip_check_id(img_bytes)
                print(f"[admin-auto] is_id={id_result.get('is_id')} conf={id_result.get('confidence',0):.2f} ocr={repr(id_result.get('ocr','')[:80])}")
                if id_result.get("is_id") and id_result.get("confidence", 0) >= 0.45:
                    db.execute("UPDATE profiles SET role='admin' WHERE id=?", (profile["id"],))
                    auto_approved = True
                    print(f"[admin-auto] β approved uid={uid}")
                else:
                    print(f"[admin-auto] β rejected uid={uid} β not recognized as ID card")
        except Exception as e:
            # Best-effort: a failed auto-check falls through to manual review.
            print(f"[admin-auto] check failed: {e}")
    if auto_approved:
        # Still log it but mark as auto-approved
        db.execute(
            "INSERT INTO admin_requests (id,user_id,email,name,role_title,reason,id_image_url,status) VALUES (?,?,?,?,?,?,?,?)",
            (str(uuid.uuid4()), user_id, data.get("email",""), data.get("name",""),
             data.get("role_title",""), data.get("reason",""), id_image_url, "approved")
        )
        db.commit(); db.close()
        broker.publish("admin", {"type": "new_request"})
        return {"ok": True, "auto_approved": True, "message": "ID verified β admin access granted!"}
    else:
        db.execute(
            "INSERT INTO admin_requests (id,user_id,email,name,role_title,reason,id_image_url) VALUES (?,?,?,?,?,?,?)",
            (str(uuid.uuid4()), user_id, data.get("email",""), data.get("name",""),
             data.get("role_title",""), data.get("reason",""), id_image_url)
        )
        db.commit(); db.close()
        broker.publish("admin", {"type": "new_request"})
        return {"ok": True, "auto_approved": False, "message": "Request submitted β pending review"}
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ββ AI IMAGE FEATURES βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Feature 1: POST /search/image β image search in search bar | |
| # Feature 2: POST /posts/with-image β create post + auto-match via SSE | |
| # Feature 3: NSFW check on upload β POST /upload/checked | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| import io, threading | |
| import numpy as np | |
| from PIL import Image | |
# -- lazy model singletons ----------------------------------------------------
# Heavy HF models are loaded on first use and cached in these module globals;
# each model family gets its own lock so unrelated first-loads don't serialize.
_dino_proc = _dino_model = None   # DINOv2 image encoder (similarity search)
_nsfw_proc = _nsfw_model = None   # NSFW image classifier (upload gate)
_model_lock = threading.Lock()      # guards DINO + NSFW lazy init
_florence_lock = threading.Lock()   # guards Florence-2 lazy init
_qwen_lock = threading.Lock()       # guards Qwen lazy init
_siglip_lock = threading.Lock()     # guards SigLIP lazy init
def _load_dino():
    """Lazily load the DINOv2-small processor+model; thread-safe, cached.

    Double-checked locking: the unlocked check keeps the hot path cheap,
    the locked re-check prevents a duplicate load on concurrent first calls.
    """
    global _dino_proc, _dino_model
    if _dino_model is None:
        with _model_lock:
            if _dino_model is None:
                from transformers import AutoImageProcessor, AutoModel
                _dino_proc = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
                _dino_model = AutoModel.from_pretrained("facebook/dinov2-small")
                _dino_model.eval()  # inference only — disable dropout etc.
    return _dino_proc, _dino_model
def _load_nsfw():
    """Lazily load the Falconsai NSFW classifier; thread-safe, cached.

    Shares _model_lock with _load_dino (same double-checked-locking pattern).
    """
    global _nsfw_proc, _nsfw_model
    if _nsfw_model is None:
        with _model_lock:
            if _nsfw_model is None:
                from transformers import AutoImageProcessor, AutoModelForImageClassification
                _nsfw_proc = AutoImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
                _nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
                _nsfw_model.eval()  # inference only
    return _nsfw_proc, _nsfw_model
| # ββ helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _embed(img_bytes: bytes) -> list:
    """Return the DINOv2 CLS-token embedding of raw image bytes as a list."""
    import torch
    processor, net = _load_dino()
    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    batch = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        hidden = net(**batch).last_hidden_state
    # CLS token (position 0) is the whole-image representation.
    return hidden[:, 0, :].squeeze().numpy().tolist()
def _embed_path(path: str) -> list:
    """Embed an image file on disk (thin wrapper over _embed)."""
    data = Path(path).read_bytes()
    return _embed(data)
| def _cosine(a, b) -> float: | |
| va, vb = np.array(a), np.array(b) | |
| d = np.linalg.norm(va) * np.linalg.norm(vb) | |
| return float(np.dot(va, vb) / d) if d > 0 else 0.0 | |
def _is_nsfw(img_bytes: bytes) -> bool:
    """Return True when the classifier's top label for the image is 'nsfw'.

    Raises on unreadable bytes or model failure; the upload endpoint wraps
    this call and treats errors as "not NSFW" (best-effort gate).
    """
    import torch
    proc, model = _load_nsfw()
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    inputs = proc(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    label = model.config.id2label[int(logits.argmax(-1))]
    return label.lower() == "nsfw"
| # ββ DB migration: add embedding column ββββββββββββββββββββββββββββββββββββββββ | |
def _migrate():
    """Idempotent startup migration: ensure posts.embedding exists."""
    conn = get_db()
    existing = {col[1] for col in conn.execute("PRAGMA table_info(posts)").fetchall()}
    if "embedding" not in existing:
        conn.execute("ALTER TABLE posts ADD COLUMN embedding TEXT")
        conn.commit()
    conn.close()
_migrate()
| # ββ startup backfill (embed existing posts that have images) ββββββββββββββββββ | |
def _backfill():
    """Startup task: embed posts that have an image but no stored embedding.

    Runs in a daemon thread so app startup is not blocked; each UPDATE uses
    its own short-lived connection since this runs off the main thread.
    """
    db = get_db()
    rows = db.execute(
        "SELECT id, image_url FROM posts WHERE image_url IS NOT NULL AND embedding IS NULL AND is_deleted=0"
    ).fetchall()
    db.close()
    for row in rows:
        # BUG FIX: lstrip("/images/") strips a *character set* ({/,i,m,a,g,e,s}),
        # not a prefix — uuid-hex filenames starting with any of those letters
        # (e.g. "a3f4….jpg") lost leading characters, so the file was never
        # found and the post was silently skipped. removeprefix() strips the
        # literal "/images/" prefix only.
        fname = row["image_url"].removeprefix("/images/").lstrip("/")
        path = os.path.join(IMG_DIR, fname)
        if not os.path.exists(path):
            continue
        try:
            vec = _embed_path(path)
            db2 = get_db()
            db2.execute("UPDATE posts SET embedding=? WHERE id=?", (json.dumps(vec), row["id"]))
            db2.commit()
            db2.close()
        except Exception as e:
            print(f"[backfill] {row['id']}: {e}")  # best-effort; keep going
threading.Thread(target=_backfill, daemon=True).start()
# Pending embeddings keyed by image URL — temporary hand-off between the
# upload endpoint (which computes them) and the post-create endpoint (which
# pops and persists them). NOTE(review): uploads that never become posts
# leave orphaned vectors here until the process restarts.
_pending_emb: dict[str, list] = {}
_pending_siglip: dict[str, list] = {}
| # ββ FEATURE 3: moderated image upload βββββββββββββββββββββββββββββββββββββββββ | |
async def upload_checked(file: UploadFile = File(...), user=Depends(require_user)):
    """Upload an image with an NSFW gate + pre-computed DINOv2 embedding.

    Returns {"url": "/images/<uuid>.<ext>"}. The embedding is parked in
    _pending_emb under that URL until /posts/ai persists it with the post.
    Raises 400 for non-image content types, 422 when the NSFW model flags it.
    """
    # Guard against a missing content type instead of crashing with 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(400, "File must be an image")
    data = await file.read()
    # -- NSFW gate ------------------------------------------------------------
    try:
        if _is_nsfw(data):
            raise HTTPException(422, "Image rejected: inappropriate content detected")
    except HTTPException:
        raise
    except Exception as e:
        print(f"[nsfw] {e}")  # model failure must not block a legitimate upload
    # -- save to disk ---------------------------------------------------------
    ext = file.content_type.split("/")[1].replace("jpeg", "jpg")
    filename = f"{uuid.uuid4()}.{ext}"
    fpath = os.path.join(IMG_DIR, filename)
    with open(fpath, "wb") as f:
        f.write(data)
    # BUG FIX: the URL previously interpolated a literal "(unknown)" placeholder
    # instead of the generated filename, so every upload returned the same bogus
    # URL and the parked embedding could never be matched back to its post.
    url = f"/images/{filename}"
    # Embed synchronously — the vector must be ready before /posts/ai runs,
    # and the client is already awaiting this response anyway.
    try:
        _pending_emb[url] = _embed_path(fpath)
    except Exception as e:
        print(f"[embed upload] {e}")  # non-fatal: post saves without embedding
    return {"url": url}
| # ββ FEATURE 2: create post + auto-match βββββββββββββββββββββββββββββββββββββββ | |
async def create_post_ai(request: Request, user=Depends(require_user)):
    """
    Create a post like POST /posts, plus the AI extras:
      - attaches the embeddings parked in _pending_emb / _pending_siglip by
        the upload endpoint (keyed on image_url)
      - broadcasts the new post on the "all" channel
      - spawns a background lost/found image match and notifies the author
        on their user channel when similar opposite-status posts exist
    Returns the full post row joined with author profile fields.
    """
    data = await request.json()
    pid = str(uuid.uuid4())
    db = get_db()
    iurl = data.get("image_url")
    # pop (not get): each parked embedding belongs to exactly one post
    emb = json.dumps(_pending_emb.pop(iurl)) if iurl and iurl in _pending_emb else None
    siglip = json.dumps(_pending_siglip.pop(iurl)) if iurl and iurl in _pending_siglip else None
    db.execute(
        "INSERT INTO posts (id,author_id,title,description,location,category,status,image_url,embedding,siglip_embedding) "
        "VALUES (?,?,?,?,?,?,?,?,?,?)",
        (pid, user["id"], data["title"], data.get("description",""),
         data["location"], data["category"], data.get("status","found"), iurl, emb, siglip)
    )
    db.commit()
    # Re-read the row joined with the author profile for the broadcast payload.
    row = db.execute("""
        SELECT p.*, pr.uid as author_uid, pr.name as author_name,
               pr.initials as author_initials, pr.color as author_color, 0 as comment_count
        FROM posts p JOIN profiles pr ON pr.id=p.author_id WHERE p.id=?
    """, (pid,)).fetchone()
    db.close()
    post_data = dict(row)
    broker.publish("all", {"type": "new_post", "post": post_data})
    # -- auto-match: search opposite-status posts by image --------------------
    status = data.get("status","found")
    if emb and status in ("lost", "found"):
        opposite = "found" if status == "lost" else "lost"
        def _match(post_id, emb_json, opp, author_uid):
            # Runs in a daemon thread with its own DB connection; compares the
            # new post's vector against every embedded opposite-status post.
            try:
                qvec = json.loads(emb_json)
                db2 = get_db()
                rows = db2.execute(
                    "SELECT id, title, image_url, embedding FROM posts "
                    "WHERE status=? AND is_deleted=0 AND embedding IS NOT NULL AND id!=?",
                    (opp, post_id)
                ).fetchall()
                db2.close()
                scored = []
                for r in rows:
                    try:
                        sim = _cosine(qvec, json.loads(r["embedding"]))
                        if sim > 0.5:  # only keep plausibly-same-item scores
                            scored.append({"id": r["id"], "title": r["title"],
                                           "image_url": r["image_url"], "score": round(sim,3)})
                    except Exception:
                        pass  # skip rows with unparsable embedding JSON
                scored.sort(key=lambda x: x["score"], reverse=True)
                top5 = scored[:5]
                if top5:
                    broker.publish(f"user:{author_uid}", {
                        "type": "image_matches", "post_id": post_id, "matches": top5
                    })
            except Exception as e:
                print(f"[auto-match] {e}")
        threading.Thread(target=_match, args=(pid, emb, opposite, user["uid"]), daemon=True).start()
    return post_data
| # ββ FEATURE 1: image search endpoint ββββββββββββββββββββββββββββββββββββββββββ | |
async def search_by_image(
    file: UploadFile = File(...),
    status_filter: str = "all"
):
    """Search posts by image similarity. status_filter: all|lost|found|waiting|recovered"""
    if not file.content_type.startswith("image/"):
        raise HTTPException(400, "File must be an image")
    try:
        query_vec = _embed(await file.read())
    except Exception as e:
        raise HTTPException(500, f"Could not process image: {e}")
    base_sql = (
        "SELECT p.*, pr.uid as author_uid, pr.name as author_name, "
        "pr.initials as author_initials, pr.color as author_color, 0 as comment_count "
        "FROM posts p JOIN profiles pr ON pr.id=p.author_id "
        "WHERE p.is_deleted=0 AND p.embedding IS NOT NULL"
    )
    args: list = []
    if status_filter != "all":
        base_sql += " AND p.status=?"
        args.append(status_filter)
    db = get_db()
    candidates = db.execute(base_sql, args).fetchall()
    db.close()
    hits = []
    for candidate in candidates:
        try:
            score = _cosine(query_vec, json.loads(candidate["embedding"]))
        except Exception:
            continue
        if score > 0.2:  # low threshold — same-item photos score ~0.7-0.9, unrelated ~0.1-0.3
            hit = dict(candidate)
            hit["similarity"] = round(score, 3)
            hit.pop("embedding", None)
            hits.append(hit)
    hits.sort(key=lambda h: h["similarity"], reverse=True)
    return hits[:10]
| # ββ DEBUG: check embedding status βββββββββββββββββββββββββββββββββββββββββββββ | |
async def debug_embeddings():
    """Shows which posts have embeddings stored. Remove this endpoint in production."""
    db = get_db()
    recent = db.execute(
        "SELECT id, title, status, image_url, "
        "CASE WHEN embedding IS NULL THEN 0 ELSE 1 END as has_embedding "
        "FROM posts WHERE is_deleted=0 ORDER BY created_at DESC LIMIT 50"
    ).fetchall()
    db.close()
    embedded = sum(r["has_embedding"] for r in recent)
    return {
        "total_posts": len(recent),
        "posts_with_embedding": embedded,
        "posts_without_embedding": len(recent) - embedded,
        "posts": [dict(r) for r in recent],
    }
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ββ NEW AI FEATURES βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # F-A: Florence-2 β photo auto-fill (title, desc, category) | |
| # F-B: Qwen-0.5B β natural language search parsing | |
| # F-C: SigLIP2 β live camera search + admin ID check | |
| # F-D: DINOv3 β upgraded image similarity (replaces DINOv2 gradually) | |
| # F-E: Cron β auto-status nudge for stale posts | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# -- lazy model singletons (new) ----------------------------------------------
# Same lazy-load-and-cache scheme as the DINO/NSFW globals above.
_florence_proc = _florence_model = None   # Florence-2: captioning + OCR
_qwen_tok = _qwen_model = None            # Qwen2.5-0.5B: search-query parsing
_siglip_proc = _siglip_model = None       # SigLIP: shared image/text embeddings
def _load_florence():
    """Lazily load Florence-2-base (processor + model); thread-safe, cached.

    Florence-2's remote modeling code probes for flash_attn via
    importlib.util.find_spec, so when flash_attn isn't installed a stub
    module with a real __spec__ is injected; attention then falls back to
    the eager implementation requested below.
    """
    global _florence_proc, _florence_model
    if _florence_model is None:
        with _florence_lock:
            if _florence_model is None:
                import sys, types, torch
                from transformers import AutoProcessor, AutoModelForCausalLM
                # Florence-2's modeling file calls is_flash_attn_2_available()
                # which uses importlib.util.find_spec — stub needs __spec__ set
                if "flash_attn" not in sys.modules:
                    import importlib.util
                    stub = types.ModuleType("flash_attn")
                    stub.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)
                    stub.__version__ = "0.0.0"
                    stub.flash_attn_func = None
                    stub.flash_attn_varlen_func = None
                    sys.modules["flash_attn"] = stub
                    # the interface submodule is imported separately — stub it too
                    sub = types.ModuleType("flash_attn.flash_attn_interface")
                    sub.__spec__ = importlib.util.spec_from_loader("flash_attn.flash_attn_interface", loader=None)
                    sys.modules["flash_attn.flash_attn_interface"] = sub
                _florence_proc = AutoProcessor.from_pretrained(
                    "microsoft/Florence-2-base", trust_remote_code=True)
                _florence_model = AutoModelForCausalLM.from_pretrained(
                    "microsoft/Florence-2-base", trust_remote_code=True,
                    attn_implementation="eager",
                    torch_dtype=torch.float32,
                )
                _florence_model.eval()  # inference only
    return _florence_proc, _florence_model
def _load_qwen():
    """Lazily load Qwen2.5-0.5B-Instruct tokenizer + model; thread-safe, cached."""
    global _qwen_tok, _qwen_model
    if _qwen_model is None:
        with _qwen_lock:
            if _qwen_model is None:
                from transformers import AutoTokenizer, AutoModelForCausalLM
                _qwen_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
                _qwen_model = AutoModelForCausalLM.from_pretrained(
                    "Qwen/Qwen2.5-0.5B-Instruct")
                _qwen_model.eval()  # inference only
    return _qwen_tok, _qwen_model
def _load_siglip():
    """Lazily load the SigLIP base processor + model; thread-safe, cached."""
    global _siglip_proc, _siglip_model
    if _siglip_model is None:
        with _siglip_lock:
            if _siglip_model is None:
                from transformers import AutoProcessor, AutoModel
                # siglip2 requires transformers>=4.47, use siglip-base which works with 4.40
                _siglip_proc = AutoProcessor.from_pretrained(
                    "google/siglip-base-patch16-224")
                _siglip_model = AutoModel.from_pretrained(
                    "google/siglip-base-patch16-224")
                _siglip_model.eval()  # inference only
    return _siglip_proc, _siglip_model
| # ββ Florence-2 helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _florence_describe(img_bytes: bytes) -> dict:
    """Returns {title, description, category} from image using Florence-2.

    Pipeline: caption the image, strip scene-setting filler, derive a short
    title (first clause, max 6 words), pick the most detail-bearing sentence
    as description, and map caption keywords to one of the app's categories.
    Raises on unreadable bytes / model failure — the endpoint wraps this.
    """
    import torch, re
    proc, model = _load_florence()
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    # Step 1: get raw caption from Florence (task token selects the mode)
    inputs = proc(text="<MORE_DETAILED_CAPTION>", images=img, return_tensors="pt")
    with torch.no_grad():
        ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=120, num_beams=3, do_sample=False
        )
    caption = proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
    # Step 2: extract the core object (first noun phrase — skip scene filler)
    # Remove phrases like "The image shows", "It appears", "The background"
    clean = re.sub(
        r"(the image shows?|it appears?( to be)?|the background( is)?|"
        r"the overall mood|the lighting|this is a (photo|image|picture) of)\s*",
        "", caption, flags=re.IGNORECASE
    ).strip().lstrip(",. ")
    # Step 3: build a short human title — just the object, max 6 words
    # Take up to first comma or period
    short = re.split(r"[,.]", clean)[0].strip()
    words = short.split()
    title = " ".join(words[:6]).capitalize()
    if not title:
        # fallback: cleaning removed everything — use the raw caption's start
        title = " ".join(caption.split()[:5]).capitalize()
    # Step 4: build a short practical description — 1 sentence max
    # Just the key identifying features, no scene-setting
    sentences = re.split(r"\. ", clean)
    # Pick sentence with most specific detail (longest that mentions size/color/brand)
    detail_words = {"black","white","blue","red","green","silver","gold","small","large",
                    "broken","old","new","leather","metal","plastic","keychain","strap","logo"}
    best = sentences[0]
    for s in sentences[:3]:
        if any(w in s.lower() for w in detail_words):
            best = s; break
    description = best.strip().rstrip(".")
    if len(description) > 120:
        # hard cap: keep the first 20 words of an over-long sentence
        description = " ".join(description.split()[:20])
    # Step 5: category from caption (first matching keyword bucket wins)
    kw_map = {
        "phone|laptop|tablet|charger|earphone|headphone|cable|usb|computer": "Electronics",
        "bag|backpack|purse|wallet|suitcase|pouch": "Bags",
        "watch|ring|necklace|bracelet|glasses|sunglasses|jewelry": "Accessories",
        "jacket|shirt|pants|coat|hoodie|cloth|shoe|boot|scarf|hat|cap": "Clothing",
        "id|card|passport|license|badge|student": "ID / Cards",
        "key|keychain|keyring": "Keys",
    }
    category = "Other"
    cap_lower = caption.lower()
    for pattern, cat in kw_map.items():
        if re.search(pattern, cap_lower):
            category = cat; break
    return {"title": title, "description": description, "category": category}
| # ββ Qwen helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _qwen_parse_search(query: str) -> dict:
    """Parse a natural-language search query into structured filters via Qwen.

    Returns {"keywords", "status", "location", "category"}; on any parse or
    generation failure it degrades to the raw query with neutral filters.
    """
    import torch, re
    tok, model = _load_qwen()
    system = (
        "You are a search parser for a campus lost and found app. "
        "Extract search intent from the user query and return ONLY valid JSON with these fields: "
        '{"keywords": "main search terms", "status": "lost|found|waiting|recovered|all", '
        '"location": "location name or empty string", "category": "Electronics|Bags|Accessories|Clothing|ID / Cards|Keys|Other|empty string"}. '
        "No explanation, no markdown, only the JSON object."
    )
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": query}
    ]
    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tok([text], return_tensors="pt")
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=120,
            pad_token_id=tok.eos_token_id, do_sample=False
        )
    # decode only the newly generated tokens (skip the prompt prefix)
    resp = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    # extract JSON safely — grab the first {...} span, fall back on failure
    try:
        m = re.search(r'\{.*\}', resp, re.DOTALL)
        return json.loads(m.group()) if m else {"keywords": query, "status": "all", "location": "", "category": ""}
    except Exception:
        return {"keywords": query, "status": "all", "location": "", "category": ""}
| # ββ SigLIP2 helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _siglip_embed_image(img_bytes: bytes) -> list:
    """L2-normalized SigLIP image embedding as a plain Python list."""
    import torch
    processor, net = _load_siglip()
    pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    batch = processor(images=pil_img, return_tensors="pt")
    with torch.no_grad():
        vec = net.get_image_features(pixel_values=batch["pixel_values"])
        # normalize so cosine similarity reduces to a dot product
        vec = vec / vec.norm(dim=-1, keepdim=True)
    return vec.squeeze().tolist()
def _siglip_embed_text(text: str) -> list:
    """L2-normalized SigLIP text embedding — same vector space as image embeddings."""
    import torch
    processor, net = _load_siglip()
    batch = processor(text=[text], return_tensors="pt", padding="max_length", truncation=True)
    with torch.no_grad():
        vec = net.get_text_features(input_ids=batch["input_ids"])
        vec = vec / vec.norm(dim=-1, keepdim=True)
    return vec.squeeze().tolist()
def _siglip_check_id(img_bytes: bytes) -> dict:
    """
    Detect if an image is a university/institution ID card.
    Strategy: OCR the image with Florence-2, then look for ID-card keywords.
    Much more reliable than SigLIP zero-shot for card detection.

    Returns {"is_id": bool, "confidence": float, "ocr": str} — never raises;
    OCR failure degrades to an empty transcript and low confidence.
    """
    import re, torch
    # -- 1. OCR with Florence-2 ------------------------------------------------
    ocr_text = ""
    try:
        proc, model = _load_florence()
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        img.thumbnail((768, 768))  # cap resolution to keep OCR fast
        inputs = proc(text="<OCR>", images=img, return_tensors="pt")
        with torch.no_grad():
            ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=200, num_beams=1, do_sample=False,
            )
        ocr_text = proc.batch_decode(ids, skip_special_tokens=True)[0].strip().lower()
        print(f"[id-check] OCR: {repr(ocr_text[:120])}")
    except Exception as e:
        print(f"[id-check] OCR failed: {e}")
    # -- 2. Keyword scoring ------------------------------------------------------
    # High-confidence: these words almost always mean it's an ID card.
    # NOTE(review): some entries ("Γ©cole", "carte Γ©tudiant") look
    # mojibake-mangled — confirm the file is stored/read as UTF-8, otherwise
    # the French keywords will never match OCR output.
    strong_keywords = [
        "student id", "staff id", "employee id", "faculty id",
        "university", "institute of technology", "college", "Γ©cole",
        "student card", "id card", "identity card", "access card",
        "student", "matricule", "carte Γ©tudiant", "carte d'Γ©tudiant",
        "mit id", "campus card",
    ]
    # Weaker signals — need multiple to count
    weak_keywords = [
        "id", "name", "department", "valid", "expires", "issued",
        "badge", "member", "card no", "card #",
    ]
    strong_hits = [kw for kw in strong_keywords if kw in ocr_text]
    weak_hits = [kw for kw in weak_keywords if kw in ocr_text]
    # Also check image aspect ratio — ID cards are typically landscape ~1.58:1
    try:
        img_check = Image.open(io.BytesIO(img_bytes))
        w, h = img_check.size
        # orientation-agnostic long/short ratio; max(...,1) avoids div-by-zero
        ratio = max(w, h) / max(min(w, h), 1)
        card_shape = 1.3 <= ratio <= 2.0
    except Exception:
        card_shape = False
    # Confidence ladder: any strong keyword wins; otherwise two weak keywords
    # plus a card-like shape scrape past the 0.45 threshold; else near-zero.
    if strong_hits:
        confidence = min(0.95, 0.60 + len(strong_hits) * 0.10)
    elif len(weak_hits) >= 2 and card_shape:
        confidence = 0.55
    else:
        confidence = 0.10
    is_id = confidence >= 0.45
    print(f"[id-check] strong={strong_hits} weak={weak_hits} card_shape={card_shape} β confidence={confidence:.2f} is_id={is_id}")
    return {"is_id": is_id, "confidence": round(confidence, 3), "ocr": ocr_text[:200]}
| # ββ DB migration: add nudged_at column ββββββββββββββββββββββββββββββββββββββββ | |
def _migrate_nudge():
    """Idempotent startup migration: ensure posts.nudged_at exists."""
    conn = get_db()
    existing = {col[1] for col in conn.execute("PRAGMA table_info(posts)").fetchall()}
    if "nudged_at" not in existing:
        conn.execute("ALTER TABLE posts ADD COLUMN nudged_at TEXT")
        conn.commit()
    conn.close()
_migrate_nudge()
| # ββ F-E: auto-status cron (runs every 24h) ββββββββββββββββββββββββββββββββββββ | |
def _run_nudge_cron():
    """Daemon loop: every 24h, nudge authors of stale active posts.

    A post qualifies when it is >14 days old, not recovered, not deleted,
    and has not been nudged in the last 7 days; nudged_at is stamped so the
    7-day cool-down holds. NOTE(review): the sleep comes *before* the first
    pass, so nothing is nudged until the process has been up for 24h.
    """
    while True:
        time.sleep(60 * 60 * 24)  # 24h
        try:
            db = get_db()
            now = datetime.utcnow()  # naive UTC, stored as ISO text below
            # posts older than 14 days, not recovered, not nudged in last 7 days
            rows = db.execute("""
                SELECT p.id, p.title, pr.uid as author_uid
                FROM posts p JOIN profiles pr ON pr.id=p.author_id
                WHERE p.is_deleted=0
                  AND p.status != 'recovered'
                  AND datetime(p.created_at) < datetime('now', '-14 days')
                  AND (p.nudged_at IS NULL OR datetime(p.nudged_at) < datetime('now', '-7 days'))
            """).fetchall()
            for r in rows:
                broker.publish(f"user:{r['author_uid']}", {
                    "type": "nudge",
                    "post_id": r["id"],
                    "title": r["title"],
                    "message": "Is this item still active? Tap to update or mark as recovered."
                })
                db.execute("UPDATE posts SET nudged_at=? WHERE id=?",
                           (now.isoformat(), r["id"]))
            db.commit()  # single commit for the whole batch
            db.close()
            print(f"[nudge cron] nudged {len(rows)} posts")
        except Exception as e:
            # keep the loop alive — a failed pass just retries next cycle
            print(f"[nudge cron error] {e}")
threading.Thread(target=_run_nudge_cron, daemon=True).start()
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ββ NEW ENDPOINTS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ββ F-A: auto-fill from photo βββββββββββββββββββββββββββββββββββββββββββββββββ | |
async def describe_image(file: UploadFile = File(...), user=Depends(get_current_user)):
    """Upload a photo β†’ get back {title, description, category} auto-filled by Florence-2.

    Raises 400 for non-image uploads, 500 when captioning fails.
    """
    # FIX: content_type can be None on malformed multipart uploads — the old
    # .startswith() call then crashed with AttributeError (500) instead of 400.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(400, "File must be an image")
    try:
        result = _florence_describe(await file.read())
        return result
    except Exception as e:
        raise HTTPException(500, f"Could not describe image: {e}")
| # ββ F-B: natural language search βββββββββββββββββββββββββββββββββββββββββββββ | |
| def _rule_parse_search(q: str) -> dict: | |
| """Fast rule-based NL parser β no model needed, instant.""" | |
| import re | |
| q_low = q.lower() | |
| # status | |
| status = "all" | |
| if re.search(r"\blost\b|\bperdu\b|\bΩ ΩΩΩΨ―\b", q_low): status = "lost" | |
| elif re.search(r"\bfound\b|\btrouvΓ©\b|\bΩΨ¬Ψ―\b", q_low): status = "found" | |
| # category | |
| category = "" | |
| cat_map = { | |
| "Electronics": r"phone|laptop|tablet|charger|earphone|headphone|cable|usb|computer|tΓ©lΓ©phone|ordinateur", | |
| "Bags": r"bag|backpack|purse|wallet|sac|cartable|ΨΩΩΨ¨Ψ©", | |
| "Accessories": r"watch|ring|glasses|sunglasses|bracelet|jewelry|montre|lunettes", | |
| "Clothing": r"jacket|shirt|coat|hoodie|scarf|hat|cap|veste|manteau", | |
| "ID / Cards": r"id|card|badge|carte|Ψ¨Ψ·Ψ§ΩΨ©", | |
| "Keys": r"key|keychain|clΓ©|Ω ΩΨͺΨ§Ψ", | |
| } | |
| for cat, pattern in cat_map.items(): | |
| if re.search(pattern, q_low): | |
| category = cat; break | |
| # location β extract word after location prepositions | |
| loc = "" | |
| loc_match = re.search(r"(?:near|at|in|beside|next to|devant|dans|ΨΉΩΨ―|Ψ¨Ψ¬Ψ§ΩΨ¨)\s+([\w\s]+?)(?:\s|$|,|\.)", q_low) | |
| if loc_match: | |
| loc = loc_match.group(1).strip() | |
| # strip status/category/location words from keywords | |
| keywords = q_low | |
| for pat in [r"\b(lost|found|perdu|trouvΓ©)\b", r"\b(near|at|in beside|next to)\b"]: | |
| keywords = re.sub(pat, "", keywords) | |
| keywords = " ".join(keywords.split()) | |
| return {"keywords": keywords, "status": status, "location": loc, "category": category} | |
async def ai_search(q: str = ""):
    """
    Semantic search: embed the text query with SigLIP2 and find posts
    whose IMAGE embeddings are closest in the shared vision-language space.
    Falls back to text-only posts (no image) using keyword match.
    Returns up to 10 semantic hits followed by up to 5 keyword hits.
    """
    if not q.strip():
        return []
    # ── Semantic search against post images ─────────────────────────────────
    try:
        qvec = _siglip_embed_text(q)
    except Exception as e:
        print(f"[ai/search embed error] {e}")
        qvec = None  # embedding unavailable → everything takes the keyword path
    db = get_db()
    # Fetch all non-deleted posts with full author info
    rows = db.execute(
        "SELECT p.id, p.title, p.description, p.location, p.category, p.status, "
        "p.image_url, p.created_at, p.author_id, p.siglip_embedding, "
        "pr.uid as author_uid, pr.name as author_name, "
        "pr.initials as author_initials, pr.color as author_color, pr.role as author_role, "
        "(SELECT COUNT(*) FROM comments c WHERE c.post_id=p.id) as comment_count "
        "FROM posts p JOIN profiles pr ON pr.id=p.author_id "
        "WHERE p.is_deleted=0"
    ).fetchall()
    db.close()
    scored = []    # posts ranked by cross-modal cosine similarity
    fallback = []  # posts with no image → score by title/desc keyword match
    for r in rows:
        d = dict(r)
        emb = d.pop("siglip_embedding", None)
        if qvec and emb:
            try:
                sim = _cosine(qvec, json.loads(emb))
                d["similarity"] = round(sim, 3)
                scored.append(d)
            except Exception:
                pass  # malformed stored embedding — skip the post
        else:
            # no image embedding → keyword fallback, capped at 0.3 so it
            # never outranks genuine semantic hits of similar strength
            text = f"{d.get('title','')} {d.get('description','')}".lower()
            words = [w for w in q.lower().split() if len(w) > 2]
            hits = sum(1 for w in words if w in text)
            if hits > 0:
                d["similarity"] = round(hits / max(len(words), 1) * 0.3, 3)
                fallback.append(d)
    # No threshold β€” textβ†’image cross-modal scores top out ~0.05-0.15
    # Just sort by similarity and take the top results
    scored.sort(key=lambda x: x["similarity"], reverse=True)
    scored = [d for d in scored if d["similarity"] > 0]
    # FIX: rank the keyword fallback too — it carried a similarity score but
    # was previously returned in arbitrary database order.
    fallback.sort(key=lambda x: x["similarity"], reverse=True)
    # Merge: semantic first, then keyword fallback
    results = (scored[:10] + fallback[:5])
    return results
| # ββ F-C: live camera search (SigLIP2 image-to-image) βββββββββββββββββββββββββ | |
async def camera_search(
    file: UploadFile = File(...),
    status_filter: str = "all"
):
    """Fast image search using SigLIP2 β€” for live camera scanning.

    Embeds the uploaded frame and returns up to 5 posts whose stored image
    embeddings exceed cosine similarity 0.15, best match first. status_filter
    restricts to 'lost'/'found' when not 'all'.
    """
    # FIX: content_type can be None on malformed multipart uploads — guard
    # so the endpoint returns 400 instead of crashing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(400, "File must be an image")
    try:
        qvec = _siglip_embed_image(await file.read())
    except Exception as e:
        raise HTTPException(500, f"Could not process image: {e}")
    db = get_db()
    sql = (
        "SELECT p.id, p.title, p.status, p.location, p.image_url, p.siglip_embedding "
        "FROM posts p WHERE p.is_deleted=0 AND p.siglip_embedding IS NOT NULL"
    )
    params = []
    if status_filter != "all":
        sql += " AND p.status=?"; params.append(status_filter)
    rows = db.execute(sql, params).fetchall()
    db.close()
    scored = []
    for r in rows:
        try:
            sim = _cosine(qvec, json.loads(r["siglip_embedding"]))
            if sim > 0.15:  # empirical floor for image-to-image matches
                scored.append({
                    "id": r["id"],
                    "title": r["title"],
                    "status": r["status"],
                    "location": r["location"],
                    "image_url": r["image_url"],
                    "similarity": round(sim, 3)
                })
        except Exception:
            pass  # malformed stored embedding — skip the post
    scored.sort(key=lambda x: x["similarity"], reverse=True)
    return scored[:5]
| # ββ F-C: SigLIP migration + backfill βββββββββββββββββββββββββββββββββββββββββ | |
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"  # update this if model changes
def _migrate_siglip():
    """Ensure the siglip_embedding column exists and invalidate stored
    embeddings whenever the configured model id differs from the one
    recorded in the meta table (embeddings are model-specific)."""
    conn = get_db()
    existing = {c[1] for c in conn.execute("PRAGMA table_info(posts)").fetchall()}
    if "siglip_embedding" not in existing:
        conn.execute("ALTER TABLE posts ADD COLUMN siglip_embedding TEXT")
        conn.commit()
    # track which model generated the stored embeddings
    conn.execute("CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT)")
    recorded = conn.execute("SELECT value FROM meta WHERE key='siglip_model'").fetchone()
    if recorded is None or recorded[0] != SIGLIP_MODEL_ID:
        # model switched (or never recorded): wipe so the backfill recomputes
        conn.execute("UPDATE posts SET siglip_embedding=NULL")
        conn.execute("INSERT OR REPLACE INTO meta (key,value) VALUES ('siglip_model',?)",
                     (SIGLIP_MODEL_ID,))
        conn.commit()
        print(f"[siglip] model changed β wiped embeddings, will recompute")
    conn.commit()
    conn.close()
_migrate_siglip()
def _backfill_siglip():
    """One-shot background pass: compute a SigLIP embedding for every post
    image that lacks one (posts created before the feature, or after a
    model switch wiped the column). Runs on a daemon thread at startup."""
    db = get_db()
    rows = db.execute(
        "SELECT id, image_url FROM posts "
        "WHERE image_url IS NOT NULL AND siglip_embedding IS NULL AND is_deleted=0"
    ).fetchall()
    db.close()
    for row in rows:
        # FIX: str.lstrip("/images/") strips *characters*, not a prefix — it
        # mangled any filename starting with one of '/','i','m','a','g','e','s'.
        # Resolve via basename like the other embedding helpers do.
        fname = os.path.basename(row["image_url"])
        path = os.path.join(IMG_DIR, fname)
        if not os.path.exists(path): continue
        try:
            with open(path, "rb") as fh:  # close the handle promptly
                vec = _siglip_embed_image(fh.read())
            # fresh short-lived connection per row (thread-safe, no long lock)
            db2 = get_db()
            db2.execute("UPDATE posts SET siglip_embedding=? WHERE id=?",
                        (json.dumps(vec), row["id"]))
            db2.commit(); db2.close()
        except Exception as e:
            print(f"[siglip backfill] {row['id']}: {e}")
threading.Thread(target=_backfill_siglip, daemon=True).start()
| # ββ YOLO-World: fast open-vocabulary object detector ββββββββββββββββββββββββ | |
| # ~300-500ms on CPU vs 3-8s for OWLv2. Uses CNN not ViT β much faster. | |
| # YOLOWorld takes a text prompt like "white headphones" β returns bounding boxes. | |
| _yw_model = None | |
| def _extract_yolo_query(text: str) -> str: | |
| """ | |
| Extract the single best YOLO label from a natural language description. | |
| Strategy: find the first concrete object noun after 'a/an/the' or 'of'. | |
| YOLO works best with simple COCO-style labels: 'headphones', 'backpack', 'bottle'. | |
| """ | |
| import re | |
| # Known COCO/YOLO object classes β prefer these if found anywhere in the text | |
| yolo_classes = [ | |
| "headphones","earphones","earbuds","backpack","bag","wallet","purse","phone", | |
| "mobile","laptop","tablet","keyboard","mouse","monitor","charger","cable", | |
| "bottle","cup","mug","glass","book","notebook","pen","pencil","glasses", | |
| "sunglasses","hat","cap","helmet","jacket","shirt","shoe","watch","ring", | |
| "bracelet","necklace","keys","remote","umbrella","ball","toy","box","chair", | |
| "table","desk","sofa","couch","bed","door","window","bicycle","car","person", | |
| "cat","dog","bottle","scissors","knife","fork","spoon","bowl","plate", | |
| ] | |
| text_lower = text.lower() | |
| for cls in yolo_classes: | |
| if cls in text_lower: | |
| return cls | |
| # Fallback: first noun after 'a/an/the' at the start of the sentence | |
| m = re.search(r'\b(?:a|an|the)\s+(?:\w+\s+)?(\w{4,})\b', text_lower) | |
| if m: | |
| noise = {"pair","kind","type","sort","piece","set","lot","bit","group","bunch"} | |
| word = m.group(1) | |
| if word not in noise: | |
| return word | |
| # Last resort: first long word | |
| words = re.findall(r'\b\w{4,}\b', text_lower) | |
| noise = {"this","that","with","have","from","they","some","into","there", | |
| "image","shows","photo","picture","resting","wearing","holding"} | |
| for w in words: | |
| if w not in noise: | |
| return w | |
| return "object" | |
def _load_yolo_world():
    """Lazily construct the module-wide YOLO-World model (idempotent)."""
    global _yw_model, _yw_current_classes
    if _yw_model is not None:
        return _yw_model
    print("[yolo-world] loadingβ¦")
    from ultralytics import YOLOWorld as _YW
    _yw_model = _YW("yolov8s-worldv2.pt")
    # Force CLIP download NOW at startup by setting a dummy class.
    # This prevents the 338MB download from blocking the first real request.
    _yw_model.set_classes(["object"])
    _yw_current_classes = None  # reset so the first real query sets classes properly
    print("[yolo-world] ready β")
    return _yw_model
def _yolo_world_find(image: Image.Image, text_query: str, threshold: float = 0.05):
    """Run YOLO-World on *image* for one text label.

    Returns detections as [{"score": float, "box": [x1, y1, x2, y2]}],
    highest score first."""
    global _yw_current_classes
    model = _load_yolo_world()
    # Only call set_classes when the query changed — it is expensive and may
    # re-download the CLIP text encoder.
    if text_query != _yw_current_classes:
        model.set_classes([text_query])
        _yw_current_classes = text_query
    hits = []
    for result in model.predict(image, conf=threshold, verbose=False):
        for box in result.boxes:
            hits.append({
                "score": float(box.conf[0]),
                "box": [float(v) for v in box.xyxy[0].tolist()],
            })
    return sorted(hits, key=lambda d: d["score"], reverse=True)
_ref_image_query_cache: dict = {} # md5 of ref JPEG -> {"query": str, "embedding": list}
# Flat items YOLO-World detects poorly; these route to SigLIP sliding-window below
_CARD_LIKE_NOUNS = {"card","id","badge","pass","ticket","document","license","permit","certificate"}
def _yolo_world_find_by_image(frame: Image.Image, query_img: Image.Image, threshold: float = 0.01):
    """
    Hybrid reference-image finder:
      1. Caption the ref image once with Florence (cached by md5).
      2. Extract core noun.
      3a. If noun is a flat/card-like item YOLO can't detect -> SigLIP sliding-window similarity.
      3b. Otherwise -> YOLO-World with the noun.

    Returns [{"score": float, "box": [x1, y1, x2, y2]}] — at most one entry
    on the sliding-window path, possibly empty.
    """
    import torch, hashlib, io as _io, numpy as np
    # Hash a re-encoded (quality=60) JPEG of the ref so repeated calls with the
    # same reference skip the expensive Florence caption + SigLIP embed.
    buf = _io.BytesIO()
    query_img.save(buf, format="JPEG", quality=60)
    img_hash = hashlib.md5(buf.getvalue()).hexdigest()
    if img_hash not in _ref_image_query_cache:
        try:
            proc, model = _load_florence()
            # downscale before captioning to keep Florence latency low
            q = query_img.copy(); q.thumbnail((256, 256))
            inputs = proc(text="<MORE_DETAILED_CAPTION>", images=q, return_tensors="pt")
            with torch.no_grad():
                ids = model.generate(
                    input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"],
                    max_new_tokens=30, num_beams=1, do_sample=False,
                )
            caption = proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
            query = _extract_yolo_query(caption)
            print(f"[ref-image] '{caption}' β '{query}'")
        except Exception as e:
            # captioning failure is non-fatal: "object" routes to sliding window
            print(f"[ref-image error] {e}")
            query = "object"
        # Pre-compute SigLIP embedding of the ref image (used for sliding window)
        ref_buf = _io.BytesIO()
        query_img.save(ref_buf, format="JPEG")
        ref_emb = _siglip_embed_image(ref_buf.getvalue())
        _ref_image_query_cache[img_hash] = {"query": query, "embedding": ref_emb}
    cached = _ref_image_query_cache[img_hash]
    query = cached["query"]
    ref_emb = cached["embedding"]
    # ── Card/flat items: SigLIP sliding window ──────────────────────────────
    if query in _CARD_LIKE_NOUNS or query == "object":
        W, H = frame.size
        best_score, best_box = 0.0, None
        # Try 3 scales Γ— sliding windows; windows overlap by roughly 2/3
        for scale in [0.25, 0.40, 0.60]:
            ww, wh = max(60, int(W * scale)), max(40, int(H * scale))
            step_x, step_y = max(20, ww // 3), max(20, wh // 3)
            for x in range(0, W - ww + 1, step_x):
                for y in range(0, H - wh + 1, step_y):
                    patch = frame.crop((x, y, x + ww, y + wh))
                    pb = _io.BytesIO(); patch.save(pb, format="JPEG", quality=70)
                    sim = _cosine(ref_emb, _siglip_embed_image(pb.getvalue()))
                    if sim > best_score:
                        best_score, best_box = sim, [x, y, x + ww, y + wh]
        print(f"[ref-image sliding] best_sim={round(best_score,3)}")
        # high bar (0.70): same-object crops score very high image-to-image
        if best_score > 0.70 and best_box:
            return [{"score": float(best_score), "box": best_box}]
        return []
    # ── Normal objects: YOLO-World ──────────────────────────────────────────
    return _yolo_world_find(frame, query, threshold)
| # ββ F-C: store SigLIP embedding on upload ββββββββββββββββββββββββββββββββββββ | |
| # Preload YOLO-World in background so first scan is instant | |
| threading.Thread(target=_load_yolo_world, daemon=True).start() | |
| # Patch upload/checked to also compute siglip embedding | |
| _orig_upload_checked = upload_checked.__wrapped__ if hasattr(upload_checked, '__wrapped__') else None | |
async def upload_checked_v2(file: UploadFile = File(...), user=Depends(require_user)):
    """
    Full pipeline upload:
    1. NSFW check
    2. Save file
    3. Compute DINOv2 embedding (for similarity search)
    4. Compute SigLIP2 embedding (for camera search)
    Returns {"url": relative /images/ URL, "fullUrl": absolute URL}.
    """
    # content_type can be None on malformed multipart — treat as non-image
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(400, "File must be an image")
    data = await file.read()
    # NSFW gate — best-effort: a checker crash must not block uploads
    try:
        if _is_nsfw(data):
            raise HTTPException(422, "Image rejected: inappropriate content detected")
    except HTTPException:
        raise
    except Exception as e:
        print(f"[nsfw] {e}")
    # save under a collision-proof name derived from the content type
    ext = file.content_type.split("/")[1].replace("jpeg", "jpg")
    filename = f"{uuid.uuid4()}.{ext}"
    fpath = os.path.join(IMG_DIR, filename)
    with open(fpath, "wb") as f:
        f.write(data)
    # FIX: the URL must reference the file just saved — it previously held a
    # literal placeholder instead of the generated filename, so every upload
    # produced the same broken URL and the pending-embedding keys collided.
    url = f"/images/{filename}"
    # DINOv2 embedding (keyed by url, attached to the post on creation)
    try:
        _pending_emb[url] = _embed_path(fpath)
    except Exception as e:
        print(f"[dino embed] {e}")
    # SigLIP2 embedding
    try:
        _pending_siglip[url] = _siglip_embed_image(data)
    except Exception as e:
        print(f"[siglip embed] {e}")
    return {"url": url, "fullUrl": f"{os.environ.get('SPACE_URL', '')}{url}"}
| # ββ F-D: admin ID check βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
async def check_id_image(file: UploadFile = File(...)):
    """Check if uploaded image looks like a staff/student ID card using SigLIP2."""
    content_type = file.content_type
    if not content_type.startswith("image/"):
        raise HTTPException(400, "File must be an image")
    try:
        payload = await file.read()
        return _siglip_check_id(payload)
    except Exception as e:
        raise HTTPException(500, f"Could not analyse image: {e}")
| # ββ /ai/find-in-frame β real-world object finder βββββββββββββββββββββββββββββ | |
async def camera_search_ws(websocket: WebSocket):
    """Live-finder WebSocket.

    Protocol: the client first sends a JSON text message — either
    {"target": "<text query>"} or {"ref_image_b64": "<base64 JPEG>"} — then
    streams frames as binary messages. Each frame is answered with
    {"found": bool, "box": [x1,y1,x2,y2] normalized 0..1 or None,
    "confidence": float}.

    Detection runs on a single-worker thread pool so the CPU-bound YOLO call
    never blocks the event loop, and frames are processed strictly in order.
    """
    import asyncio, base64
    from concurrent.futures import ThreadPoolExecutor
    _executor = ThreadPoolExecutor(max_workers=1)
    await websocket.accept()
    target = None   # current text query (None while in ref-image mode)
    ref_img = None  # current reference PIL image (None while in text mode)
    async def run_yolo(frame_img):
        """Run YOLO in a thread so it doesn't block the async event loop."""
        loop = asyncio.get_event_loop()
        if ref_img:
            return await loop.run_in_executor(_executor, _yolo_world_find_by_image, frame_img, ref_img)
        else:
            return await loop.run_in_executor(_executor, _yolo_world_find, frame_img, target)
    try:
        while True:
            data = await websocket.receive()
            # ── Config message ───────────────────────────────────────────────
            if "text" in data:
                msg = json.loads(data["text"])
                target = msg.get("target", "").strip()
                ref_b64 = msg.get("ref_image_b64")
                if ref_b64:
                    # reference image wins: decode it and clear any text target
                    ref_img = Image.open(io.BytesIO(base64.b64decode(ref_b64))).convert("RGB")
                    target = None
                else:
                    ref_img = None
                # Warm up in thread — sets classes + downloads CLIP if needed,
                # so the first real frame is not hit by that latency
                if target:
                    loop = asyncio.get_event_loop()
                    await loop.run_in_executor(_executor, _yolo_world_find,
                                               Image.new("RGB", (64, 64)), target)
                await websocket.send_json({"status": "ready"})
                continue
            # ── Frame bytes ──────────────────────────────────────────────────
            if "bytes" in data:
                if not target and not ref_img:
                    # no query configured yet — answer "not found" rather than error
                    await websocket.send_json({"found": False, "box": None, "confidence": 0.0})
                    continue
                try:
                    frame_img = Image.open(io.BytesIO(data["bytes"])).convert("RGB")
                    W, H = frame_img.size
                    # cap width at 640px to bound per-frame detector latency
                    if W > 640:
                        s = 640 / W
                        frame_img = frame_img.resize((640, int(H * s)), Image.BILINEAR)
                        W, H = frame_img.size
                    detections = await run_yolo(frame_img)
                    if detections:
                        x1, y1, x2, y2 = detections[0]["box"]
                        # normalize the box to 0..1 so the client can scale it
                        await websocket.send_json({
                            "found": True,
                            "box": [x1/W, y1/H, x2/W, y2/H],
                            "confidence": round(detections[0]["score"], 2),
                        })
                    else:
                        await websocket.send_json({"found": False, "box": None, "confidence": 0.0})
                except Exception as e:
                    # a bad frame must not kill the socket — report "not found"
                    print(f"[ws frame] {e}")
                    await websocket.send_json({"found": False, "box": None, "confidence": 0.0})
    except Exception:
        # client disconnected (or fatal error): release the worker thread
        _executor.shutdown(wait=False)
| # Keep the old HTTP endpoint for backwards compat but just call yolo-world | |
async def find_in_frame(
    frame: UploadFile = File(...),
    ref_image: Optional[UploadFile] = File(None),
    target: str = Form(""),
):
    """HTTP fallback for the WebSocket finder: locate *target* (text) or the
    object shown in *ref_image* within a single uploaded frame.

    Returns {"found": bool, "box": normalized 0..1 coords or None,
    "label": the YOLO noun used, "confidence": float}.
    """
    frame_bytes = await frame.read()
    frame_img = Image.open(io.BytesIO(frame_bytes)).convert("RGB")
    W, H = frame_img.size
    # cap width at 640px to bound detector latency
    if W > 640:
        scale = 640 / W
        frame_img = frame_img.resize((640, int(H * scale)), Image.BILINEAR)
        W, H = frame_img.size
    # NOTE(review): target_name is never used below — presumably intended for
    # a user-facing message; confirm before removing.
    target_name = target.strip() if target and target != "__ref_image__" else "your item"
    try:
        if target and target != "__ref_image__":
            # Text query always takes priority β€” fastest and most accurate
            yolo_query = _extract_yolo_query(target)
            print(f"[camera] '{target}' β '{yolo_query}'")
            detections = _yolo_world_find(frame_img, yolo_query, threshold=0.01)
        elif ref_image:
            # Image-only mode: describe the ref image once with Florence
            import hashlib, io as _refio
            ref_bytes = await ref_image.read()
            ref_img = Image.open(io.BytesIO(ref_bytes)).convert("RGB")
            detections = _yolo_world_find_by_image(frame_img, ref_img)
            # Get the actual noun Florence derived (cached after first call);
            # re-encode at quality=60 to reproduce the md5 cache key used there
            buf = _refio.BytesIO(); ref_img.save(buf, format="JPEG", quality=60)
            yolo_query = (_ref_image_query_cache.get(hashlib.md5(buf.getvalue()).hexdigest()) or {}).get("query", "?")
        else:
            # neither a text target nor a reference image was supplied
            return {"found": False, "box": None, "label": "", "confidence": 0.0}
        print(f"[camera] {W}x{H} β {len(detections)} detections, top={round(detections[0]['score'],2) if detections else 'none'}")
    except Exception as e:
        print(f"[find-in-frame] {e}")
        return {"found": False, "box": None, "label": "", "confidence": 0.0}
    if not detections:
        return {"found": False, "box": None, "label": yolo_query, "confidence": 0.0}
    best = detections[0]
    x1, y1, x2, y2 = best["box"]
    return {"found": True, "box": [x1/W, y1/H, x2/W, y2/H], "label": yolo_query, "confidence": round(best["score"], 2)}
async def debug_yolo_world(file: UploadFile = File(...), target: str = Form("headphones")):
    """Test YOLO-World: upload any photo, specify what to find."""
    import traceback
    try:
        raw = await file.read()
        picture = Image.open(io.BytesIO(raw)).convert("RGB")
        found = _yolo_world_find(picture, target, threshold=0.01)
        return {
            "target": target,
            "image_size": f"{picture.width}x{picture.height}",
            "detections": found[:5],
            "model_ready": _yw_model is not None,
            "current_classes": _yw_current_classes,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
async def debug_florence():
    """Check if Florence-2 loads correctly. Remove in production.

    Walks the full load path (torch import, transformers import, processor,
    model, tiny test inference) and records each stage in `steps`, so a
    failure response shows exactly how far the pipeline got.
    """
    import sys, traceback
    steps = []
    try:
        steps.append("importing torch")
        import torch
        steps.append(f"torch ok β version {torch.__version__}")
        steps.append("importing transformers AutoProcessor")
        from transformers import AutoProcessor
        steps.append("importing transformers AutoModelForCausalLM")
        from transformers import AutoModelForCausalLM
        steps.append("transformers ok")
        steps.append("loading processor from microsoft/Florence-2-base")
        proc = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        steps.append("processor loaded β")
        steps.append("loading model from microsoft/Florence-2-base")
        import sys, types, importlib.util
        # Florence-2's remote code imports flash_attn unconditionally; stub the
        # module (and its interface submodule) so loading works on CPU-only
        # machines where the flash_attn wheel is not installed.
        if "flash_attn" not in sys.modules:
            stub = types.ModuleType("flash_attn")
            stub.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)
            stub.__version__ = "0.0.0"
            stub.flash_attn_func = None
            stub.flash_attn_varlen_func = None
            sys.modules["flash_attn"] = stub
            sub = types.ModuleType("flash_attn.flash_attn_interface")
            sub.__spec__ = importlib.util.spec_from_loader("flash_attn.flash_attn_interface", loader=None)
            sys.modules["flash_attn.flash_attn_interface"] = sub
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Florence-2-base", trust_remote_code=True,
            attn_implementation="eager", torch_dtype=torch.float32,
        )
        steps.append("model loaded β")
        steps.append("running test inference")
        from PIL import Image as PILImage
        import io as _io
        # tiny 32x32 white image — smallest input that exercises the full path
        img = PILImage.new("RGB", (32, 32), color=(255,255,255))
        buf = _io.BytesIO(); img.save(buf, format="JPEG"); buf.seek(0)
        inputs = proc(text="<MORE_DETAILED_CAPTION>", images=img, return_tensors="pt")
        with torch.no_grad():
            ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=20, num_beams=1, do_sample=False
            )
        out = proc.batch_decode(ids, skip_special_tokens=True)[0]
        steps.append(f"inference ok β output: '{out}'")
        return {"ok": True, "steps": steps}
    except Exception as e:
        tb = traceback.format_exc()
        return {"ok": False, "steps": steps, "error": str(e), "traceback": tb}
| # ββ DEBUG: test NL search βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
async def debug_search(q: str = "lost keys near library"):
    """Echo the rule-based parse of *q* so the NL parser can be inspected."""
    return {"query": q, "parsed": _rule_parse_search(q)}
| # ββ DEBUG: test semantic search βββββββββββββββββββββββββββββββββββββββββββββββ | |
async def debug_semantic_search(q: str = "keys on a table"):
    """Score every embedded post against *q* and report model bookkeeping."""
    import traceback
    try:
        qvec = _siglip_embed_text(q)
    except Exception as e:
        return {"ok": False, "error": f"embed failed: {e}", "traceback": traceback.format_exc()}
    conn = get_db()
    posts = conn.execute(
        "SELECT id, title, image_url, siglip_embedding "
        "FROM posts WHERE is_deleted=0 AND siglip_embedding IS NOT NULL"
    ).fetchall()
    conn.close()
    ranked = []
    for post in posts:
        entry = {"id": post["id"], "title": post["title"]}
        try:
            entry["similarity"] = round(_cosine(qvec, json.loads(post["siglip_embedding"])), 4)
        except Exception as e:
            entry["error"] = str(e)
        ranked.append(entry)
    ranked.sort(key=lambda x: x.get("similarity", 0), reverse=True)
    # report which model id is recorded in the DB vs what is configured
    conn = get_db()
    stored_model = conn.execute("SELECT value FROM meta WHERE key='siglip_model'").fetchone()
    conn.close()
    return {
        "ok": True,
        "query": q,
        "model_in_db": stored_model[0] if stored_model else "unknown",
        "model_loaded": SIGLIP_MODEL_ID,
        "text_vec_len": len(qvec),
        "scores": ranked
    }
async def debug_reembed():
    """Wipe and recompute all siglip embeddings with the current model.

    No auth — remove in prod.
    """
    global _siglip_proc, _siglip_model
    # Force unload cached model so it reloads with correct model id
    _siglip_proc = None
    _siglip_model = None
    db = get_db()
    # Null out the embeddings and poison the recorded model id so the
    # _migrate_siglip() call below sees a mismatch and re-records the
    # current SIGLIP_MODEL_ID.
    db.execute("UPDATE posts SET siglip_embedding=NULL")
    db.execute("INSERT OR REPLACE INTO meta (key,value) VALUES ('siglip_model','__reset__')")
    db.commit()
    db.close()
    _migrate_siglip()
    db = get_db()
    rows = db.execute(
        "SELECT id, image_url FROM posts WHERE image_url IS NOT NULL AND is_deleted=0"
    ).fetchall()
    db.close()
    done, failed, skipped = 0, 0, 0
    for row in rows:
        # resolve the web path "/images/<file>" to a file inside IMG_DIR
        fname = os.path.basename(row["image_url"])
        path = os.path.join(IMG_DIR, fname)
        if not os.path.exists(path):
            skipped += 1; continue
        try:
            vec = _siglip_embed_image(open(path, "rb").read())
            # fresh short-lived connection per row so one failure cannot
            # wedge a long-running transaction
            db2 = get_db()
            db2.execute("UPDATE posts SET siglip_embedding=? WHERE id=?",
                        (json.dumps(vec), row["id"]))
            db2.commit(); db2.close()
            done += 1
        except Exception as e:
            print(f"[reembed] {row['id']}: {e}"); failed += 1
    return {"reembedded": done, "failed": failed, "skipped": skipped}
| # ββ ADMIN: re-embed all posts with current SigLIP model ββββββββββββββββββββββ | |
async def reembed_siglip(user=Depends(require_admin)):
    """
    Regenerate siglip_embedding for every live post image using the currently
    loaded SigLIP model. Run this whenever you switch SigLIP model versions.

    Returns {"reembedded": n, "failed": n}; missing files count as failed.
    """
    db = get_db()
    rows = db.execute(
        "SELECT id, image_url FROM posts WHERE is_deleted=0 AND image_url IS NOT NULL"
    ).fetchall()
    db.close()
    done, failed = 0, 0
    for r in rows:
        try:
            # FIX: image_url is a *web* path like "/images/<file>", not a
            # filesystem path — the old code passed it to os.path.exists
            # unchanged whenever it started with "/", so every such post was
            # counted as failed. Always resolve against IMG_DIR by basename,
            # matching debug_reembed and _backfill_siglip.
            img_path = os.path.join(IMG_DIR, os.path.basename(r["image_url"]))
            if not os.path.exists(img_path):
                failed += 1; continue
            with open(img_path, "rb") as fh:  # close the handle promptly
                vec = _siglip_embed_image(fh.read())
            db2 = get_db()
            db2.execute("UPDATE posts SET siglip_embedding=? WHERE id=?",
                        (json.dumps(vec), r["id"]))
            db2.commit(); db2.close()
            done += 1
        except Exception as e:
            print(f"[reembed] {r['id']} failed: {e}")
            failed += 1
    return {"reembedded": done, "failed": failed}
| # ββ DEBUG: show actual similarity scores ββββββββββββββββββββββββββββββββββββββ | |
async def debug_scores(q: str = "keys on a table"):
    """Return the raw cosine similarity of *q* against every stored embedding,
    sorted best-first — for tuning search thresholds."""
    try:
        qvec = _siglip_embed_text(q)
    except Exception as e:
        return {"ok": False, "error": str(e)}
    db = get_db()
    rows = db.execute(
        "SELECT id, title, siglip_embedding FROM posts WHERE is_deleted=0 AND siglip_embedding IS NOT NULL"
    ).fetchall()
    db.close()
    scores = []
    for r in rows:
        try:
            sim = _cosine(qvec, json.loads(r["siglip_embedding"]))
            scores.append({"title": r["title"], "similarity": round(sim, 4)})
        except Exception:
            # FIX: was a bare `except:` which also swallowed SystemExit /
            # KeyboardInterrupt; malformed embeddings are still skipped.
            pass
    scores.sort(key=lambda x: x["similarity"], reverse=True)
    return {"ok": True, "query": q, "scores": scores}
async def debug_embeddings_check():
    """Check what image paths exist vs what DB has."""
    conn = get_db()
    posts = conn.execute(
        "SELECT id, title, image_url, "
        "CASE WHEN siglip_embedding IS NULL THEN 0 ELSE 1 END as has_emb "
        "FROM posts WHERE is_deleted=0"
    ).fetchall()
    conn.close()
    report = []
    for post in posts:
        url = post["image_url"] or ""
        # candidate 1: basename resolved inside IMG_DIR; candidate 2: url as-is
        local = os.path.join(IMG_DIR, os.path.basename(url))
        on_disk = os.path.exists(local)
        if not on_disk and url.startswith("/"):
            on_disk = os.path.exists(url)
        report.append({
            "title": post["title"],
            "image_url": url,
            "has_emb": bool(post["has_emb"]),
            "path_tried": local,
            "file_exists": on_disk,
        })
    listing = os.listdir(IMG_DIR) if os.path.exists(IMG_DIR) else []
    return {
        "IMG_DIR": IMG_DIR,
        "files_in_dir": listing[:20],
        "posts": report
    }
| # ββ DEBUG: test find-in-frame with a URL βββββββββββββββββββββββββββββββββββββ | |
async def debug_find_in_frame(target: str = "keys", img_url: str = ""):
    """Test the finder using an existing uploaded image URL.

    Accepts a local "/images/..." URL, an http(s) URL, or nothing (falls back
    to the first post image in the DB). Runs Florence captioning plus its
    <OD> object-detection task on the image and returns the raw results.
    """
    import urllib.request, traceback
    try:
        if img_url.startswith("/images/"):
            path = os.path.join(IMG_DIR, os.path.basename(img_url))
            data = open(path, "rb").read()
        elif img_url.startswith("http"):
            data = urllib.request.urlopen(img_url, timeout=5).read()
        else:
            # use first post image
            db = get_db()
            row = db.execute("SELECT image_url FROM posts WHERE image_url IS NOT NULL LIMIT 1").fetchone()
            db.close()
            if not row: return {"error": "no posts with images"}
            path = os.path.join(IMG_DIR, os.path.basename(row["image_url"]))
            data = open(path, "rb").read()
            img_url = row["image_url"]
        # NOTE(review): StarletteUpload is imported but never used — confirm and drop
        from starlette.datastructures import UploadFile as StarletteUpload
        import io as _io
        # call the logic directly
        frame_img = Image.open(_io.BytesIO(data)).convert("RGB")
        W, H = frame_img.size
        proc, model = _load_florence()
        import torch, re
        def _run(task, image, text_input=None, max_new=300):
            # run one Florence task prompt and post-process to structured output
            prompt = task if text_input is None else task + text_input
            inputs = proc(text=prompt, images=image, return_tensors="pt")
            with torch.no_grad():
                ids = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=max_new, num_beams=3, do_sample=False,
                    early_stopping=False,
                )
            raw = proc.batch_decode(ids, skip_special_tokens=False)[0]
            return proc.post_process_generation(raw, task=task, image_size=(W, H))
        caption = _run("<MORE_DETAILED_CAPTION>", frame_img).get("<MORE_DETAILED_CAPTION>","")
        od_raw = _run("<OD>", frame_img)
        od_data = od_raw.get("<OD>", {})
        detections = list(zip(od_data.get("labels",[]), od_data.get("bboxes",[])))
        return {
            "img_url": img_url,
            "target": target,
            "caption": caption,
            "detections": [{"label": l, "bbox": b} for l,b in detections],
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}