|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Form |
|
|
from fastapi.security import OAuth2PasswordBearer |
|
|
from jose import JWTError, jwt |
|
|
from pinecone import Pinecone |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from PIL import Image |
|
|
import io |
|
|
from transformers import AutoProcessor, CLIPModel |
|
|
import numpy as np |
|
|
from datetime import datetime, timedelta |
|
|
|
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
SECRET_KEY = os.getenv("JWT_SECRET", "default_secret") |
|
|
ALGORITHM = "HS256" |
|
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30 |
|
|
|
|
|
|
|
|
fake_users_db = { |
|
|
"admin": { |
|
|
"username": "admin", |
|
|
"password": "password123" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY") |
|
|
if not PINECONE_API_KEY: |
|
|
raise RuntimeError("PINECONE_API_KEY is not set. Please set it in the environment or .env file.") |
|
|
|
|
|
|
|
|
pc = Pinecone(api_key=PINECONE_API_KEY) |
|
|
index_name = "images-index" |
|
|
unsplash_index = pc.Index(index_name) |
|
|
|
|
|
|
|
|
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") |
|
|
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") |
|
|
model.eval() |
|
|
|
|
|
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token") |
|
|
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: timedelta = None): |
|
|
to_encode = data.copy() |
|
|
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) |
|
|
to_encode.update({"exp": expire}) |
|
|
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) |
|
|
|
|
|
|
|
|
def authenticate_user(username: str, password: str): |
|
|
user = fake_users_db.get(username) |
|
|
if not user or user["password"] != password: |
|
|
return None |
|
|
return user |
|
|
|
|
|
|
|
|
def get_current_user(token: str = Depends(oauth2_scheme)): |
|
|
try: |
|
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) |
|
|
username: str = payload.get("sub") |
|
|
if username is None or username not in fake_users_db: |
|
|
raise HTTPException(status_code=401, detail="Invalid authentication") |
|
|
return username |
|
|
except JWTError: |
|
|
raise HTTPException(status_code=401, detail="Invalid authentication") |
|
|
|
|
|
|
|
|
@app.post("/token") |
|
|
async def login(username: str = Form(...), password: str = Form(...)): |
|
|
user = authenticate_user(username, password) |
|
|
if not user: |
|
|
raise HTTPException(status_code=400, detail="Incorrect username or password") |
|
|
access_token = create_access_token(data={"sub": user["username"]}) |
|
|
return {"access_token": access_token, "token_type": "bearer"} |
|
|
|
|
|
|
|
|
def get_text_embedding(text: str): |
|
|
inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True) |
|
|
text_features = model.get_text_features(**inputs) |
|
|
return text_features.detach().cpu().numpy().flatten().tolist() |
|
|
|
|
|
|
|
|
def get_image_embedding(image: Image.Image): |
|
|
inputs = processor(images=image, return_tensors="pt") |
|
|
image_features = model.get_image_features(**inputs) |
|
|
return image_features.detach().cpu().numpy().flatten().tolist() |
|
|
|
|
|
|
|
|
def search_similar_images(embedding: list, top_k: int = 10): |
|
|
results = unsplash_index.query( |
|
|
vector=embedding, |
|
|
top_k=top_k, |
|
|
include_metadata=True, |
|
|
namespace="image-search-dataset" |
|
|
) |
|
|
return results["matches"] |
|
|
|
|
|
|
|
|
@app.get("/search/text/") |
|
|
async def search_by_text(query: str, user: str = Depends(get_current_user)): |
|
|
if not query: |
|
|
raise HTTPException(status_code=400, detail="Query text is required") |
|
|
embedding = get_text_embedding(query) |
|
|
matches = search_similar_images(embedding) |
|
|
return {"matches": [{"id": m["id"], "score": m["score"], "url": m["metadata"]["url"]} for m in matches]} |
|
|
|
|
|
|
|
|
@app.post("/search/image/") |
|
|
async def search_by_image(file: UploadFile = File(...), user: str = Depends(get_current_user)): |
|
|
try: |
|
|
image_data = await file.read() |
|
|
image = Image.open(io.BytesIO(image_data)).convert("RGB") |
|
|
embedding = get_image_embedding(image) |
|
|
matches = search_similar_images(embedding) |
|
|
return {"matches": [{"id": m["id"], "score": m["score"], "url": m["metadata"]["url"]} for m in matches]} |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run(app, host="0.0.0.0", port=7960) |
|
|
|