Spaces:
Runtime error
Runtime error
| # chromedb_service.py | |
| import io | |
| import requests | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image | |
| from functools import lru_cache | |
| import chromadb | |
| from tensorflow.keras.applications.mobilenet_v2 import ( | |
| MobileNetV2, | |
| preprocess_input | |
| ) | |
# ============================================================
# CONFIG
# ============================================================
# Address of the remote Chroma server that every function in this module talks to.
CHROMA_HOST: str = "https://stable-diffusion-engine.oneiro-lego.com"
# (width, height) handed to PIL's Image.resize before an image is encoded.
IMAGE_SIZE: tuple[int, int] = (224, 224)
# Timeout, in seconds, given to requests.get when downloading an image.
TIMEOUT: int = 5
# ============================================================
# CHROMA CLIENT (SINGLETON)
# ============================================================
# A single HTTP client for the remote Chroma server, created at import time and
# shared by every function in this module.
# NOTE(review): CHROMA_HOST is a full "https://..." URL, while HttpClient also has
# separate `port` and `ssl` arguments — confirm the installed chromadb version
# derives scheme/port from the URL. Depending on the version, the constructor may
# also contact the server, so an unreachable host could fail the module import.
chroma_client = chromadb.HttpClient(host=CHROMA_HOST)
# ============================================================
# MODEL (SINGLETON)
# ============================================================
# MobileNetV2 feature extractor, built once at import time and reused by every
# encode call. `weights="imagenet"` loads the pretrained ImageNet weights (Keras
# downloads them on first use). `include_top=False` drops the classifier head,
# and `pooling="avg"` applies global average pooling, so each image maps to a
# single flat feature vector.
model = MobileNetV2(
    weights="imagenet",
    include_top=False,
    pooling="avg"
)
# Length of one embedding vector, i.e. the last axis of the pooled model output.
EMBEDDING_DIM = model.output_shape[-1]
# ============================================================
# IMAGE UTILS
# ============================================================
def download_image(url: str) -> Image.Image:
    """Fetch ``url`` over HTTP and decode the response body as an RGB PIL image.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or when a connect/read
            exceeds TIMEOUT seconds.
        PIL.UnidentifiedImageError: if the body cannot be decoded as an image.
    """
    reply = requests.get(url, timeout=TIMEOUT)
    reply.raise_for_status()
    raw_bytes = io.BytesIO(reply.content)
    decoded = Image.open(raw_bytes)
    return decoded.convert("RGB")
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Resize ``image`` to IMAGE_SIZE and apply MobileNetV2 input scaling.

    The resized image is converted to a NumPy array and passed through Keras'
    ``preprocess_input``, which (per the Keras docs) scales pixel values into
    [-1, 1]. For an RGB input the array is (height, width, 3).
    """
    resized = image.resize(IMAGE_SIZE)
    return preprocess_input(np.array(resized))
# ============================================================
# IMAGE ENCODING
# ============================================================
def encode_image_from_url(image_url: str) -> np.ndarray:
    """Download ``image_url`` and return its MobileNetV2 embedding vector.

    The image goes through ``download_image`` and ``preprocess_image``. It is then
    wrapped in a batch of one and run through ``model`` in inference mode
    (``training=False``). The single output row is returned.

    NOTE(review): ``lru_cache`` is imported at the top of the file, but no caching
    is applied here — every call re-downloads and re-encodes the image.
    """
    pixels = preprocess_image(download_image(image_url))
    batch = tf.expand_dims(tf.convert_to_tensor(pixels, dtype=tf.float32), axis=0)
    features = model(batch, training=False)
    return features.numpy()[0]
# ============================================================
# IMAGE API (FLASK COMPATIBLE)
# ============================================================
def add_image_to_chroma(
    collection_name: str,
    id: str,
    image_url: str,
    metadata: dict | None
):
    """Embed the image at ``image_url`` and store the vector in ``collection_name``.

    The embedding is computed before Chroma is contacted, so a failed download or
    decode leaves the collection untouched.

    Args:
        collection_name: Target collection; created if it does not exist yet.
        id: Record id. The name shadows the builtin but is kept for callers.
        image_url: URL of the image, fetched via ``encode_image_from_url``.
        metadata: Metadata stored with the vector. ``None``/``{}`` are sent as
            "no metadata", because Chroma rejects empty metadata dicts.
    """
    vector = encode_image_from_url(image_url)
    # `get_or_create_collection` has no `dimension` parameter; passing one raised
    # TypeError on every call. Chroma fixes a collection's dimensionality from the
    # first embedding stored, so no explicit size is needed.
    collection = chroma_client.get_or_create_collection(name=collection_name)
    # NOTE(review): `add` does not overwrite an existing id (add_document uses
    # `upsert`); confirm whether re-adding an image should replace it.
    collection.add(
        ids=[id],
        embeddings=[vector.tolist()],
        metadatas=[metadata] if metadata else None,
    )
# ============================================================
# TEXT API
# ============================================================
def add_document(
    collection_name: str,
    id: str,
    text: str,
    metadata: dict
):
    """Insert or replace the text record ``id`` in ``collection_name``.

    Uses ``upsert``, so a record that already has this id is overwritten. The
    collection is created on first use, and ``text`` is embedded by the
    collection's embedding function.
    """
    target = chroma_client.get_or_create_collection(name=collection_name)
    target.upsert(
        documents=[text],
        metadatas=[metadata],
        ids=[id],
    )
def delete_document(collection_name: str, id: str):
    """Remove the record ``id`` from ``collection_name``.

    Because ``get_or_create_collection`` is used, a missing collection is
    created (empty) instead of raising an error.
    """
    chroma_client.get_or_create_collection(name=collection_name).delete(ids=[id])
def delete_collection(collection_name: str):
    """Delete the Chroma collection named ``collection_name`` on the server."""
    chroma_client.delete_collection(collection_name)
# ============================================================
# SEARCH (API expected by app.py)
# ============================================================
def search(
    collection_name: str,
    query: str,
    metadata: dict | None,
    n_results: int
):
    """Text similarity search in ``collection_name`` (Flask-compatible API).

    Args:
        collection_name: Collection to query; created (empty) if missing.
        query: Query text, embedded by the collection's embedding function.
        metadata: Optional Chroma ``where`` filter; ``None`` or ``{}`` means
            "no filter".
        n_results: Maximum number of hits to request from Chroma.

    Returns:
        The hits as produced by ``parse_chromadb_response`` (best score first).
    """
    collection = chroma_client.get_or_create_collection(
        name=collection_name
    )
    results = collection.query(
        query_texts=[query],
        # Chroma can reject an empty dict as a `where` clause ("Expected where to
        # have exactly one operator"); None reliably means "no filter".
        where=metadata or None,
        n_results=n_results
    )
    return parse_chromadb_response(results)
| # ============================================================ | |
| # RESPONSE PARSER | |
| # ============================================================ | |
| def parse_chromadb_response(response: dict) -> list[dict]: | |
| output = [] | |
| for i in range(len(response["ids"][0])): | |
| distance = float(response["distances"][0][i]) | |
| score = round(1 / (1 + distance), 4) | |
| output.append({ | |
| "id": response["ids"][0][i], | |
| "distance": distance, | |
| "score": score, | |
| "document": ( | |
| response["documents"][0][i] | |
| if response.get("documents") else None | |
| ), | |
| "metadata": ( | |
| response["metadatas"][0][i] | |
| if response.get("metadatas") else None | |
| ) | |
| }) | |
| # 🔥 TRI SERVEUR PAR SCORE (DESC) | |
| output.sort(key=lambda x: x["score"], reverse=True) | |
| return output | |