AdarshDRC committed on
Commit
3e805ab
·
verified ·
1 Parent(s): 6dc24b0

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +22 -0
  2. main.py +104 -0
  3. requirements.txt +93 -0
  4. src/cloud_db.py +58 -0
  5. src/models.py +58 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime
2
+ FROM python:3.10
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Copy requirements and install them
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ # Copy the rest of the backend code
12
+ COPY . .
13
+
14
+ # Create the temp directory and give it permission to save images
15
+ RUN mkdir -p temp_uploads
16
+ RUN chmod -R 777 temp_uploads
17
+
18
+ # Hugging Face requires apps to run on port 7860
19
+ EXPOSE 7860
20
+
21
+ # Start the FastAPI server
22
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import List
4
+ from PIL import Image
5
+ import os
6
+ import shutil
7
+ import uuid
8
+ import re
9
+ import inflect # <-- NEW: Import inflect
10
+
11
+ from src.models import AIModelManager
12
+ from src.cloud_db import CloudDB
13
+
14
+ app = FastAPI()
15
+
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=["*"],
19
+ allow_credentials=True,
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
+
24
+ print("Loading AI Models and Cloud DB...")
25
+ ai = AIModelManager()
26
+ db = CloudDB()
27
+ # Initialize the inflect engine
28
+ p = inflect.engine()
29
+ print("Ready!")
30
+
31
+ os.makedirs("temp_uploads", exist_ok=True)
32
+
33
+ # --- NEW: Standardization Function ---
34
+ def standardize_category_name(name: str) -> str:
35
+ """Converts ' Cows ', 'COWS', or 'cow' all into 'cow'."""
36
+ # 1. Lowercase and strip accidental edge spaces
37
+ clean_name = name.strip().lower()
38
+
39
+ # 2. Replace inner spaces with underscores (e.g., 'sports cars' -> 'sports_cars')
40
+ clean_name = re.sub(r'\s+', '_', clean_name)
41
+
42
+ # 3. Remove weird special characters just in case (keep only letters, numbers, underscores)
43
+ clean_name = re.sub(r'[^\w\s]', '', clean_name)
44
+
45
+ # 4. Convert plural to singular (if it's already singular, it returns False, so we keep the clean_name)
46
+ singular_name = p.singular_noun(clean_name)
47
+ if singular_name:
48
+ return singular_name
49
+
50
+ return clean_name
51
+ # -------------------------------------
52
+
53
+ @app.post("/api/upload")
54
+ async def upload_new_images(files: List[UploadFile] = File(...), folder_name: str = Form(...)):
55
+ """Handles bulk uploading of multiple images at once."""
56
+ uploaded_urls = []
57
+
58
+ # Clean the folder name before doing anything else!
59
+ standardized_folder = standardize_category_name(folder_name)
60
+
61
+ try:
62
+ for file in files:
63
+ temp_path = f"temp_uploads/{file.filename}"
64
+ with open(temp_path, "wb") as buffer:
65
+ shutil.copyfileobj(file.file, buffer)
66
+
67
+ # Upload to Cloudinary using the perfectly clean folder name
68
+ image_url = db.upload_image(temp_path, standardized_folder)
69
+
70
+ img = Image.open(temp_path).convert('RGB')
71
+ vector = ai.encode_image(img)
72
+
73
+ image_id = str(uuid.uuid4())
74
+ db.add_vector(vector, image_url, image_id)
75
+
76
+ os.remove(temp_path)
77
+ uploaded_urls.append(image_url)
78
+
79
+ # Return the standardized name so the frontend knows what was actually saved
80
+ return {
81
+ "message": f"Successfully added {len(files)} images to category '{standardized_folder}'!",
82
+ "urls": uploaded_urls
83
+ }
84
+
85
+ except Exception as e:
86
+ raise HTTPException(status_code=500, detail=str(e))
87
+
88
+ # ... (Your /api/search endpoint stays exactly the same) ...
89
+ @app.post("/api/search")
90
+ async def search_database(file: UploadFile = File(...)):
91
+ try:
92
+ temp_path = f"temp_uploads/query_{file.filename}"
93
+ with open(temp_path, "wb") as buffer:
94
+ shutil.copyfileobj(file.file, buffer)
95
+
96
+ img = Image.open(temp_path).convert('RGB')
97
+ vector = ai.encode_image(img)
98
+
99
+ results = db.search(vector, top_k=10)
100
+ os.remove(temp_path)
101
+
102
+ return {"results": results}
103
+ except Exception as e:
104
+ raise HTTPException(status_code=500, detail=str(e))
requirements.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-doc==0.0.4
2
+ annotated-types==0.7.0
3
+ anyio==4.12.1
4
+ certifi==2026.2.25
5
+ charset-normalizer==3.4.4
6
+ click==8.3.1
7
+ cloudinary==1.44.1
8
+ contourpy==1.3.3
9
+ cuda-bindings==12.9.4
10
+ cuda-pathfinder==1.4.0
11
+ cycler==0.12.1
12
+ fastapi==0.135.1
13
+ filelock==3.25.0
14
+ fonttools==4.61.1
15
+ fsspec==2026.2.0
16
+ h11==0.16.0
17
+ hf-xet==1.3.2
18
+ httpcore==1.0.9
19
+ httpx==0.28.1
20
+ huggingface_hub==0.36.2
21
+ idna==3.11
22
+ inflect==7.5.0
23
+ Jinja2==3.1.6
24
+ kiwisolver==1.4.9
25
+ markdown-it-py==4.0.0
26
+ MarkupSafe==3.0.3
27
+ matplotlib==3.10.8
28
+ mdurl==0.1.2
29
+ more-itertools==10.8.0
30
+ mpmath==1.3.0
31
+ networkx==3.6.1
32
+ numpy==2.4.2
33
+ nvidia-cublas-cu12==12.8.4.1
34
+ nvidia-cuda-cupti-cu12==12.8.90
35
+ nvidia-cuda-nvrtc-cu12==12.8.93
36
+ nvidia-cuda-runtime-cu12==12.8.90
37
+ nvidia-cudnn-cu12==9.10.2.21
38
+ nvidia-cufft-cu12==11.3.3.83
39
+ nvidia-cufile-cu12==1.13.1.3
40
+ nvidia-curand-cu12==10.3.9.90
41
+ nvidia-cusolver-cu12==11.7.3.90
42
+ nvidia-cusparse-cu12==12.5.8.93
43
+ nvidia-cusparselt-cu12==0.7.1
44
+ nvidia-nccl-cu12==2.27.5
45
+ nvidia-nvjitlink-cu12==12.8.93
46
+ nvidia-nvshmem-cu12==3.4.5
47
+ nvidia-nvtx-cu12==12.8.90
48
+ opencv-python==4.13.0.92
49
+ orjson==3.11.7
50
+ packaging==24.2
51
+ pillow==12.1.1
52
+ pinecone==8.1.0
53
+ pinecone-client==6.0.0
54
+ pinecone-plugin-assistant==3.0.2
55
+ pinecone-plugin-interface==0.0.7
56
+ polars==1.38.1
57
+ polars-runtime-32==1.38.1
58
+ protobuf==7.34.0
59
+ psutil==7.2.2
60
+ pydantic==2.12.5
61
+ pydantic_core==2.41.5
62
+ Pygments==2.19.2
63
+ pyparsing==3.3.2
64
+ python-dateutil==2.9.0.post0
65
+ python-dotenv==1.2.2
66
+ python-multipart==0.0.22
67
+ PyYAML==6.0.3
68
+ regex==2026.2.28
69
+ requests==2.32.5
70
+ rich==14.3.3
71
+ safetensors==0.7.0
72
+ scipy==1.17.1
73
+ sentencepiece==0.2.1
74
+ setuptools==82.0.0
75
+ shellingham==1.5.4
76
+ six==1.17.0
77
+ starlette==0.52.1
78
+ sympy==1.14.0
79
+ tokenizers==0.21.4
80
+ torch==2.10.0
81
+ torchvision==0.25.0
82
+ tqdm==4.67.3
83
+ transformers==4.48.0
84
+ triton==3.6.0
85
+ typeguard==4.5.1
86
+ typer==0.24.1
87
+ typer-slim==0.24.0
88
+ typing-inspection==0.4.2
89
+ typing_extensions==4.15.0
90
+ ultralytics==8.4.19
91
+ ultralytics-thop==2.0.18
92
+ urllib3==2.6.3
93
+ uvicorn==0.41.0
src/cloud_db.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cloudinary
3
+ import cloudinary.uploader
4
+ from pinecone import Pinecone
5
+ from dotenv import load_dotenv
6
+
7
+ # Load keys from the .env file
8
+ load_dotenv()
9
+
10
+ class CloudDB:
11
+ def __init__(self):
12
+ # 1. Connect to Cloudinary
13
+ cloudinary.config(
14
+ cloud_name=os.getenv("CLOUDINARY_CLOUD_NAME"),
15
+ api_key=os.getenv("CLOUDINARY_API_KEY"),
16
+ api_secret=os.getenv("CLOUDINARY_API_SECRET")
17
+ )
18
+
19
+ # 2. Connect to Pinecone
20
+ self.pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
21
+ self.index = self.pc.Index(os.getenv("PINECONE_INDEX_NAME"))
22
+
23
+ def upload_image(self, file_path, folder_name="visual_search"):
24
+ """Uploads an image to Cloudinary and returns the public URL."""
25
+ response = cloudinary.uploader.upload(file_path, folder=folder_name)
26
+ return response['secure_url']
27
+
28
+ def add_vector(self, vector, image_url, image_id):
29
+ """Saves the vector and the image URL to Pinecone."""
30
+ # Convert numpy array to list for Pinecone
31
+ vector_list = vector.tolist() if hasattr(vector, 'tolist') else vector
32
+
33
+ self.index.upsert(vectors=[{
34
+ "id": image_id,
35
+ "values": vector_list,
36
+ "metadata": {"image_url": image_url}
37
+ }])
38
+
39
+ def search(self, query_vector, top_k=10, min_score=0.60): # <-- CHANGED baseline to 0.60
40
+ """Searches Pinecone and filters out baseline 'random noise' matches."""
41
+ vector_list = query_vector.tolist() if hasattr(query_vector, 'tolist') else query_vector
42
+
43
+ response = self.index.query(
44
+ vector=vector_list,
45
+ top_k=top_k,
46
+ include_metadata=True
47
+ )
48
+
49
+ results = []
50
+ for match in response['matches']:
51
+ # Only keep the image if it's an ACTUAL mathematical match (60% or higher)
52
+ if match['score'] >= min_score:
53
+ results.append({
54
+ "url": match['metadata']['image_url'],
55
+ "score": match['score']
56
+ })
57
+
58
+ return results
src/models.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/models.py
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoProcessor, AutoModel
5
+ from ultralytics import YOLO
6
+
7
+ class AIModelManager:
8
+ def __init__(self):
9
+ # Load SigLIP (Vision & Text Encoder)
10
+ self.processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224",use_fast=False)
11
+ self.model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
12
+ self.model.eval() # Set to evaluation mode
13
+
14
+ # Load YOLOv11 (Nano version for speed)
15
+ self.yolo = YOLO('yolov8n.pt') # Will auto-download the tiny weights
16
+
17
+ def encode_image(self, image: Image.Image):
18
+ """Converts a PIL Image into a vector."""
19
+ inputs = self.processor(images=image, return_tensors="pt")
20
+ with torch.no_grad():
21
+ outputs = self.model.get_image_features(**inputs)
22
+
23
+ # Extract the raw tensor from the output object
24
+ if hasattr(outputs, 'image_embeds'):
25
+ image_features = outputs.image_embeds
26
+ elif hasattr(outputs, 'pooler_output'):
27
+ image_features = outputs.pooler_output
28
+ else:
29
+ image_features = outputs
30
+
31
+ return image_features.flatten().numpy()
32
+
33
+ def encode_text(self, text: str):
34
+ """Converts a text string into a vector."""
35
+ inputs = self.processor(text=text, return_tensors="pt", padding="max_length")
36
+ with torch.no_grad():
37
+ outputs = self.model.get_text_features(**inputs)
38
+
39
+ # Hugging Face quirk: Extract the raw tensor from the output object
40
+ if hasattr(outputs, 'text_embeds'):
41
+ text_features = outputs.text_embeds
42
+ elif hasattr(outputs, 'pooler_output'):
43
+ text_features = outputs.pooler_output
44
+ else:
45
+ text_features = outputs
46
+
47
+ return text_features.flatten().numpy()
48
+
49
+ def get_crops_from_image(self, image: Image.Image):
50
+ """Uses YOLO to find objects and returns a list of cropped PIL Images."""
51
+ results = self.yolo(image, conf=0.5) # Only keep confident detections
52
+ crops = []
53
+ for result in results:
54
+ for box in result.boxes.xyxy: # Get bounding box coordinates
55
+ x1, y1, x2, y2 = map(int, box.tolist())
56
+ cropped_img = image.crop((x1, y1, x2, y2))
57
+ crops.append(cropped_img)
58
+ return crops