|
|
import os |
|
|
import uuid |
|
|
import shutil |
|
|
import sqlite3 |
|
|
import json |
|
|
import logging |
|
|
import asyncio |
|
|
import numpy as np |
|
|
import chromadb |
|
|
import cv2 |
|
|
from datetime import datetime |
|
|
from typing import List, Optional |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Request, Form |
|
|
from fastapi.responses import JSONResponse, HTMLResponse |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.templating import Jinja2Templates |
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
|
from insightface.app import FaceAnalysis |
|
|
|
|
|
|
|
|
# --- Configuration -----------------------------------------------------------
# Directory where uploaded images are written; served back via /static.
UPLOAD_DIR = "static/uploads"
# SQLite file holding photo metadata (see init_db for the schema).
DB_PATH = "photos.db"
# On-disk location for the persistent ChromaDB vector store.
CHROMA_PATH = "chroma_db"
# Prefer GPU for all model inference when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Ensure the upload directory exists before the app starts accepting files.
os.makedirs(UPLOAD_DIR, exist_ok=True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CloudzyAI")

# Registry of lazily-loaded AI models, populated once in the lifespan handler
# (keys: "clip", "blip_processor", "blip_model", "face").
ai_models = {}
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load AI models on startup, log on shutdown.

    Populates the module-level ``ai_models`` registry with the CLIP
    embedder, the BLIP captioning pair, and the InsightFace analyzer,
    then initializes the SQLite schema before the server starts serving.
    """
    logger.info("Loading CLIP model...")
    ai_models["clip"] = SentenceTransformer('clip-ViT-B-32', device=device)

    logger.info("Loading BLIP model...")
    blip_checkpoint = "Salesforce/blip-image-captioning-base"
    ai_models["blip_processor"] = BlipProcessor.from_pretrained(blip_checkpoint)
    ai_models["blip_model"] = BlipForConditionalGeneration.from_pretrained(blip_checkpoint).to(device)

    logger.info("Loading InsightFace model...")
    # Falls back to CPU execution automatically when CUDA is unavailable.
    face_engine = FaceAnalysis(name='buffalo_l', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    face_engine.prepare(ctx_id=0, det_size=(640, 640))
    ai_models["face"] = face_engine

    # Make sure the photos table exists before any request can touch it.
    init_db()

    yield

    logger.info("Shutting down...")
|
|
|
|
|
# --- Application wiring ------------------------------------------------------
app = FastAPI(lifespan=lifespan)
# Serve uploaded images and frontend assets from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Persistent vector store for CLIP image embeddings; queried by /search.
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
collection = chroma_client.get_or_create_collection(name="photo_embeddings")
|
|
|
|
|
def init_db(db_path: Optional[str] = None):
    """Create the ``photos`` table if it does not already exist.

    Args:
        db_path: Path to the SQLite database file. Defaults to the
            module-level ``DB_PATH`` when omitted (backward compatible).
    """
    path = DB_PATH if db_path is None else db_path
    conn = sqlite3.connect(path)
    try:
        # One row per uploaded photo. The AI-derived columns (caption, tags,
        # smart_analysis) start NULL and are filled in by the background task;
        # status tracks the pipeline: 'processing' -> 'completed' / 'failed'.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS photos (
                id TEXT PRIMARY KEY,
                filename TEXT,
                filepath TEXT,
                upload_date TEXT,
                caption TEXT,
                tags TEXT,
                smart_analysis TEXT,
                status TEXT
            )
        """)
        conn.commit()
    finally:
        # Close even if table creation fails (the original leaked the
        # connection on error).
        conn.close()
|
|
|
|
|
|
|
|
class PhotoResponse(BaseModel):
    """API response shape for a single photo record."""
    # Server-assigned UUID (primary key in the photos table).
    id: str
    # Original filename as supplied by the client.
    filename: str
    # URL path under /static where the image is served.
    url: str
    # BLIP-generated caption; None until background processing completes.
    caption: Optional[str] = None
    # AI-derived tags; pydantic copies this default per instance, so the
    # mutable [] default is safe here.
    tags: List[str] = []
    # Face-analysis payload (list of per-face dicts, or a "no faces" message
    # dict — see process_image_task); None while still processing.
    smart_features: Optional[dict] = None
    # ISO-8601 timestamp recorded at upload time.
    upload_date: str
|
|
|
|
|
|
|
|
def process_image_task(photo_id: str, file_path: str):
    """
    Background task that runs the AI pipeline for one uploaded photo:
    1. Generate Caption (BLIP)
    2. Analyze Faces (InsightFace)
    3. Create Embeddings (CLIP)
    4. Update SQLite metadata and the ChromaDB vector store

    On any failure the photo's status is set to 'failed' so the record
    never stays stuck in 'processing'.
    """
    logger.info(f"Starting AI analysis for {photo_id}")

    try:
        # Two decodes of the same file: PIL (RGB) for BLIP/CLIP, OpenCV
        # (BGR ndarray) for InsightFace.
        pil_image = Image.open(file_path).convert("RGB")
        cv_image = cv2.imread(file_path)

        # --- 1. Caption generation (BLIP) ---
        inputs = ai_models["blip_processor"](pil_image, return_tensors="pt").to(device)
        out = ai_models["blip_model"].generate(**inputs)
        caption = ai_models["blip_processor"].decode(out[0], skip_special_tokens=True)

        # --- 2. Face analysis (InsightFace) ---
        faces = ai_models["face"].get(cv_image)
        face_data = []
        tags = ["ai-generated"]

        if len(faces) > 0:
            avg_age = np.mean([face.age for face in faces])
            gender_counts = {"M": 0, "F": 0}
            for face in faces:
                # NOTE(review): assumes face.sex is numeric with 1 == male;
                # some insightface versions expose .sex as 'M'/'F' strings,
                # which would make every face "F" here — TODO confirm
                # against the installed insightface version.
                gender = "M" if face.sex == 1 else "F"
                gender_counts[gender] += 1
                face_data.append({
                    "age": int(face.age),
                    "gender": gender,
                    "confidence": float(face.det_score)
                })

            # Demographic summary tags derived from the detected faces.
            tags.append("person")
            tags.append(f"{len(faces)} people")
            if gender_counts["M"] > gender_counts["F"]: tags.append("mostly_male")
            if gender_counts["F"] > gender_counts["M"]: tags.append("mostly_female")
            if avg_age < 18: tags.append("youth")
            elif avg_age > 60: tags.append("senior")
            else: tags.append("adult")
        else:
            tags.append("scenery")
            # NOTE(review): face_data switches type here (list -> dict), so
            # smart_analysis stores either a list or a message dict; readers
            # must handle both shapes.
            face_data = {"message": "No faces detected"}

        # Augment tags with longer words from the caption; set() dedupes but
        # makes the final tag order non-deterministic.
        tags.extend([word for word in caption.split() if len(word) > 4])
        tags = list(set(tags))

        # --- 3. Image embedding (CLIP) for semantic search ---
        embedding = ai_models["clip"].encode(pil_image).tolist()

        # --- 4a. Persist AI results to SQLite ---
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute("""
            UPDATE photos
            SET caption = ?, tags = ?, smart_analysis = ?, status = 'completed'
            WHERE id = ?
        """, (caption, json.dumps(tags), json.dumps(face_data), photo_id))
        conn.commit()
        conn.close()

        # --- 4b. Index the embedding in ChromaDB (keyed by photo id) ---
        collection.add(
            ids=[photo_id],
            embeddings=[embedding],
            metadatas=[{"caption": caption}]
        )

        logger.info(f"AI processing completed for {photo_id}")

    except Exception as e:
        # Broad catch is deliberate: any pipeline failure marks the record
        # 'failed' instead of leaving it stuck in 'processing'.
        logger.error(f"Error processing {photo_id}: {e}")
        conn = sqlite3.connect(DB_PATH)
        conn.execute("UPDATE photos SET status = 'failed' WHERE id = ?", (photo_id,))
        conn.commit()
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the single-page UI from the Jinja2 template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
@app.post("/upload", response_model=PhotoResponse)
async def upload_photo(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
    """
    Accept an image upload and kick off background AI analysis.

    1. Validate the file is an image (by declared content type).
    2. Save it to disk under a UUID-based name.
    3. Create a 'processing' DB record.
    4. Schedule the async AI task.

    Raises:
        HTTPException(400): if the upload is not declared as an image.
    """
    # content_type can be None for malformed multipart requests; the old
    # startswith call would raise AttributeError (a 500) in that case.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    file_id = str(uuid.uuid4())
    # splitext is safer than split(".")[-1], which duplicated the whole
    # filename as the "extension" when there was no dot at all.
    ext = os.path.splitext(file.filename or "")[1].lstrip(".") or "bin"
    filename = f"{file_id}.{ext}"
    file_path = os.path.join(UPLOAD_DIR, filename)

    # Stream the upload to disk without loading it all into memory.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    upload_date = datetime.now().isoformat()
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
            INSERT INTO photos (id, filename, filepath, upload_date, status)
            VALUES (?, ?, ?, ?, 'processing')
        """, (file_id, file.filename, file_path, upload_date))
        conn.commit()
    finally:
        # Close even if the INSERT fails (the original leaked on error).
        conn.close()

    # Runs after the response is sent; results land in the DB asynchronously.
    background_tasks.add_task(process_image_task, file_id, file_path)

    return {
        "id": file_id,
        "filename": file.filename,
        # BUG FIX: previously returned the literal placeholder
        # "/static/uploads/(unknown)"; point at the file actually saved.
        "url": f"/static/uploads/{filename}",
        "upload_date": upload_date
    }
|
|
|
|
|
@app.get("/photo/{photo_id}", response_model=PhotoResponse)
async def get_photo(photo_id: str):
    """
    Return one photo's metadata, including AI results once processing
    has completed (caption/tags/smart_features are None/empty until then).

    Raises:
        HTTPException(404): if no photo with this id exists.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        row = conn.execute("SELECT * FROM photos WHERE id = ?", (photo_id,)).fetchone()
    finally:
        # Close even if the query raises (the original leaked on error).
        conn.close()

    if not row:
        raise HTTPException(status_code=404, detail="Photo not found")

    return {
        "id": row["id"],
        "filename": row["filename"],
        # filepath is stored relative ("static/uploads/..."), so a leading
        # slash yields the public URL.
        "url": f"/{row['filepath']}",
        "caption": row["caption"],
        # tags / smart_analysis are stored as JSON strings; NULL until the
        # background task finishes.
        "tags": json.loads(row["tags"]) if row["tags"] else [],
        "smart_features": json.loads(row["smart_analysis"]) if row["smart_analysis"] else None,
        "upload_date": row["upload_date"]
    }
|
|
|
|
|
@app.get("/search")
async def search_photos(q: str):
    """
    Semantic search over photos:
    1. Embed the query text with CLIP.
    2. Find the nearest image vectors in ChromaDB (top 5).
    3. Hydrate results with metadata from SQLite, preserving the
       similarity ranking.
    """
    # CLIP maps text and images into the same embedding space, so a text
    # vector can be matched directly against stored image vectors.
    query_vec = ai_models["clip"].encode(q).tolist()

    results = collection.query(
        query_embeddings=[query_vec],
        n_results=5
    )

    # Chroma returns per-query lists; we sent one query, so take index 0.
    ids = results["ids"][0]
    if not ids:
        return []

    placeholders = ",".join("?" * len(ids))
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        rows = conn.execute(
            f"SELECT * FROM photos WHERE id IN ({placeholders})", ids
        ).fetchall()
    finally:
        # Close even if the query raises (the original leaked on error).
        conn.close()

    # BUG FIX: SQL "IN (...)" does not preserve the order of the id list,
    # so the rows came back in arbitrary order and ChromaDB's similarity
    # ranking was lost. Re-order the rows to match `ids`.
    rows_by_id = {row["id"]: row for row in rows}

    response_data = []
    for pid in ids:
        row = rows_by_id.get(pid)
        if row is None:
            # Embedding exists in ChromaDB but the SQLite record is gone
            # (e.g. partial delete); skip rather than crash.
            continue
        response_data.append({
            "id": row["id"],
            "url": f"/{row['filepath']}",
            "caption": row["caption"],
            "tags": json.loads(row["tags"]) if row["tags"] else [],
            "smart_features": json.loads(row["smart_analysis"]) if row["smart_analysis"] else None,
        })

    return response_data
|
|
|
|
|
if __name__ == '__main__':
    # Dev entry point: run the ASGI app directly with uvicorn.
    # Import is local so the module stays importable without uvicorn installed.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|