# Source: chromadb-connect-api / chromedb_service.py
# Last commit: c6bd289 "Update chromedb_service.py" by Saad0KH (verified)
# chromedb_service.py
import io
import os
from functools import lru_cache

import requests
import numpy as np
import tensorflow as tf
from PIL import Image
import chromadb
from tensorflow.keras.applications.mobilenet_v2 import (
    MobileNetV2,
    preprocess_input,
)
# ============================================================
# CONFIG
# ============================================================
# Chroma server address. Override it with the CHROMA_HOST environment variable
# (e.g. for local development or tests); the historical production value
# remains the default, so existing deployments are unaffected.
CHROMA_HOST = os.environ.get(
    "CHROMA_HOST", "https://stable-diffusion-engine.oneiro-lego.com"
)
# (width, height) every image is resized to before encoding (see preprocess_image).
IMAGE_SIZE = (224, 224)
# Timeout, in seconds, for each image download (passed to requests.get).
TIMEOUT = 5
# ============================================================
# CHROMA CLIENT (SINGLETON)
# ============================================================
# Created once at import time and shared by every function in this module.
chroma_client = chromadb.HttpClient(host=CHROMA_HOST)
# ============================================================
# MODEL (SINGLETON)
# ============================================================
# MobileNetV2 with ImageNet weights, without its classification head, and with
# global average pooling so each image maps to a single feature vector.
# Loaded once at import time.
model = MobileNetV2(
    weights="imagenet",
    include_top=False,
    pooling="avg",
)
# Length of the feature vector produced by `model` (last axis of its output).
EMBEDDING_DIM = model.output_shape[-1]
# ============================================================
# IMAGE UTILS
# ============================================================
def download_image(url: str) -> Image.Image:
    """Fetch *url* over HTTP and decode the response body as an RGB PIL image.

    Raises:
        requests.HTTPError: the server answered with an error status
            (raised by ``raise_for_status``).
        requests.RequestException: connection failure, or the request took
            longer than TIMEOUT seconds.
        PIL.UnidentifiedImageError: the body is not a decodable image.

    NOTE(review): *url* presumably comes from API callers. Fetching arbitrary
    URLs server-side is an SSRF vector, and the body size is not capped.
    Confirm whether an allow-list and/or a maximum download size is needed.
    """
    reply = requests.get(url, timeout=TIMEOUT)
    reply.raise_for_status()
    # convert() produces an independent image, so the decoder (and its
    # in-memory buffer) can be closed as soon as we leave the block.
    with Image.open(io.BytesIO(reply.content)) as decoded:
        return decoded.convert("RGB")
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single MobileNetV2-ready input array.

    Steps:
      1. Convert to RGB. Grayscale, palette or RGBA images would otherwise
         yield arrays without the 3 channels the model expects. This is a
         no-op in content for images that are already RGB.
      2. Resize to IMAGE_SIZE. Aspect ratio is not preserved.
      3. Apply MobileNetV2's ``preprocess_input``, which scales pixel values
         into [-1, 1].

    Returns:
        An array of shape (height, width, 3) for IMAGE_SIZE, with no batch axis.
    """
    rgb = image.convert("RGB")
    resized = rgb.resize(IMAGE_SIZE)
    return preprocess_input(np.array(resized))
# ============================================================
# IMAGE ENCODING (CACHE)
# ============================================================
@lru_cache(maxsize=1024)
def encode_image_from_url(image_url: str) -> np.ndarray:
    """Download the image at *image_url* and return its MobileNetV2 embedding.

    Results are memoized per URL, for up to 1024 distinct URLs. A URL whose
    remote content changes keeps returning the old embedding until it is
    evicted or ``encode_image_from_url.cache_clear()`` is called. Exceptions
    are not cached, so a failed download is retried on the next call.

    Returns:
        A 1-D array of length EMBEDDING_DIM. The array is read-only because
        the very same object is returned to every caller that hits the cache.
        Copy it (``np.array(vec)``) before modifying it.

    Raises:
        Whatever ``download_image`` raises, e.g. ``requests.HTTPError``.
    """
    image = download_image(image_url)
    array = preprocess_image(image)
    tensor = tf.convert_to_tensor(array, dtype=tf.float32)
    tensor = tf.expand_dims(tensor, axis=0)  # the model expects a batch axis
    embedding = model(tensor, training=False)
    vector = embedding.numpy()[0]
    # Freeze the cached result: an in-place edit by one caller would
    # otherwise silently corrupt the embedding every later caller receives.
    vector.setflags(write=False)
    return vector
# ============================================================
# IMAGE API (FLASK COMPATIBLE)
# ============================================================
def add_image_to_chroma(
    collection_name: str,
    id: str,
    image_url: str,
    metadata: dict
):
    """Embed the image at *image_url* and store it in *collection_name* under *id*.

    The collection is created on first use. The vector is computed locally by
    ``encode_image_from_url`` and passed to Chroma explicitly, so no Chroma
    embedding function is involved. Writes go through ``upsert`` (as in
    ``add_document``), so re-submitting an existing *id* replaces its
    embedding and metadata.

    Args:
        collection_name: Target collection; created if missing.
        id: Record identifier inside the collection.
        image_url: URL of the image to download and encode.
        metadata: Metadata stored alongside the vector.

    Raises:
        Anything raised by ``encode_image_from_url`` (e.g. download errors).
        Nothing is written in that case.
    """
    vector = encode_image_from_url(image_url)
    # get_or_create_collection() takes no `dimension` argument (passing one
    # raises TypeError). Chroma fixes a collection's dimensionality from the
    # first embedding written to it.
    collection = chroma_client.get_or_create_collection(name=collection_name)
    collection.upsert(
        ids=[id],
        embeddings=[vector.tolist()],
        metadatas=[metadata],
    )
# ============================================================
# TEXT API
# ============================================================
def add_document(
    collection_name: str,
    id: str,
    text: str,
    metadata: dict
):
    """Insert or replace the text document *id* in *collection_name*.

    The collection is created if it does not exist yet. Because ``upsert`` is
    used, an existing entry with the same *id* is overwritten. No embeddings
    are supplied here; Chroma computes them from *text*.
    """
    target = chroma_client.get_or_create_collection(name=collection_name)
    record = {
        "ids": [id],
        "documents": [text],
        "metadatas": [metadata],
    }
    target.upsert(**record)
def delete_document(collection_name: str, id: str):
    """Remove the entry *id* from *collection_name*.

    Note: the collection is looked up with ``get_or_create_collection``, so
    calling this for a collection that does not exist creates it, empty, as
    a side effect.
    """
    chroma_client.get_or_create_collection(
        name=collection_name
    ).delete(ids=[id])
def delete_collection(collection_name: str):
    """Delete the collection named *collection_name* on the Chroma server."""
    chroma_client.delete_collection(
        name=collection_name,
    )
# ============================================================
# SEARCH (API EXPECTED BY app.py)
# ============================================================
def search(
    collection_name: str,
    query: str,
    metadata: dict | None,
    n_results: int
):
    """Text similarity search (Flask-compatible API).

    Args:
        collection_name: Collection to query; created if it does not exist.
        query: Query text. Chroma embeds it with the collection's embedding
            function.
        metadata: Optional metadata filter. ``None`` or ``{}`` means no
            filter. A single-key dict is passed to Chroma unchanged, so
            operator expressions such as ``{"$or": [...]}`` keep working.
            A dict with several keys is combined with ``$and``, because Chroma
            rejects ``where`` clauses with more than one top-level key.
        n_results: Maximum number of hits to return.

    Returns:
        The hits as produced by ``parse_chromadb_response``, best match first.
    """
    collection = chroma_client.get_or_create_collection(name=collection_name)
    results = collection.query(
        query_texts=[query],
        where=_metadata_to_where(metadata),
        n_results=n_results,
    )
    return parse_chromadb_response(results)


def _metadata_to_where(metadata: dict | None) -> dict | None:
    """Translate a flat metadata filter into a valid Chroma ``where`` clause."""
    if not metadata:
        return None
    if len(metadata) == 1:
        return metadata
    # Chroma requires exactly one top-level key; split {"a": 1, "b": 2}
    # into {"$and": [{"a": 1}, {"b": 2}]}.
    return {"$and": [{key: value} for key, value in metadata.items()]}
# ============================================================
# RESPONSE PARSER
# ============================================================
def parse_chromadb_response(response: dict) -> list[dict]:
output = []
for i in range(len(response["ids"][0])):
distance = float(response["distances"][0][i])
score = round(1 / (1 + distance), 4)
output.append({
"id": response["ids"][0][i],
"distance": distance,
"score": score,
"document": (
response["documents"][0][i]
if response.get("documents") else None
),
"metadata": (
response["metadatas"][0][i]
if response.get("metadatas") else None
)
})
# 🔥 TRI SERVEUR PAR SCORE (DESC)
output.sort(key=lambda x: x["score"], reverse=True)
return output