Spaces:
Runtime error
Runtime error
File size: 3,111 Bytes
099df87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os
from typing import List, Dict, Any
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, VectorParams, Distance
from sentence_transformers import SentenceTransformer
# Pull in settings from a local .env file when one is present.
load_dotenv()

# --- Qdrant connection settings ---
# Both values are stripped of surrounding whitespace; the URL falls back to a
# local instance and the API key to an empty string (treated as "no key").
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333").strip()
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "").strip()

# Name of the single collection every function in this module works on.
COLLECTION = "reports"
# --- Model Loading (from local files within the Docker image) ---
# This path corresponds to where the Dockerfile copies the model.
MODEL_PATH = "./models/all-MiniLM-L6-v2"

# Initialize the embedding model from the local path. This runs at import
# time, so a failure here aborts the import of this module.
try:
    print(f"Loading sentence-transformer model from local path: {MODEL_PATH}")
    embedding_model = SentenceTransformer(MODEL_PATH)
except Exception:
    print(f"❌ FATAL: Could not load the embedding model from {MODEL_PATH}.")
    print("This indicates an issue with the Docker build or the file path in db.py.")
    # Bare ``raise`` re-raises the active exception with its traceback intact.
    raise
else:
    # Reported outside the ``try`` so only the load itself is guarded.
    print("✅ Model loaded successfully!")
# --- Qdrant Client and Collection Setup ---
def _make_client() -> QdrantClient:
    """Build the Qdrant client from the module-level connection settings.

    The ``api_key`` argument is supplied only when ``QDRANT_API_KEY`` is
    non-empty; every client uses a 30 second timeout and skips the
    compatibility check.
    """
    options: Dict[str, Any] = {
        "url": QDRANT_URL,
        "timeout": 30.0,
        "check_compatibility": False,
    }
    if QDRANT_API_KEY:
        options["api_key"] = QDRANT_API_KEY
    return QdrantClient(**options)


# Shared client instance used by every function below.
qdrant = _make_client()
def _ensure_collection():
    """Create the Qdrant collection if it does not exist yet.

    Existence is tested explicitly instead of by catching every exception
    from a lookup, so connection, authentication or timeout failures
    propagate to the caller rather than being mistaken for a missing
    collection.
    """
    if qdrant.collection_exists(collection_name=COLLECTION):
        print(f"✅ Collection '{COLLECTION}' already exists.")
        return

    print(f"🔧 Collection '{COLLECTION}' not found. Creating it...")
    qdrant.create_collection(
        collection_name=COLLECTION,
        # The vector size has to equal the output dimension of the embedding
        # model loaded from MODEL_PATH; update it if the model changes.
        vectors_config=VectorParams(size=384, distance=Distance.COSINE),
    )
    print("✅ Collection created.")


# Runs at import time so the collection is ready before any request.
_ensure_collection()
# --- Database Functions ---
def save_report(report_id: str, text: str, title: str):
    """Embed ``text`` and upsert it into the collection under ``report_id``.

    The stored payload keeps the original ``text`` and ``title`` so they can
    be returned by the listing and search functions.
    """
    # NOTE(review): Qdrant restricts point IDs to unsigned integers or UUIDs;
    # presumably callers pass UUID strings here -- confirm against callers.
    embedding = embedding_model.encode(text).tolist()
    point = PointStruct(
        id=report_id,
        vector=embedding,
        payload={"text": text, "title": title},
    )
    qdrant.upsert(collection_name=COLLECTION, points=[point])
def list_reports() -> List[Dict[str, Any]]:
    """Return up to 50 stored reports as ``{"id", "title", "text"}`` dicts.

    No ordering is requested from Qdrant, so the result is simply the first
    page that ``scroll`` yields. Missing titles default to ``"(untitled)"``
    and missing text to an empty string.
    """
    records, _next_offset = qdrant.scroll(
        collection_name=COLLECTION,
        limit=50,
        with_payload=True,
        with_vectors=False,
    )

    reports: List[Dict[str, Any]] = []
    for record in records:
        # A point stored without a payload comes back with ``payload=None``.
        payload = record.payload or {}
        reports.append(
            {
                "id": record.id,
                "title": payload.get("title", "(untitled)"),
                "text": payload.get("text", ""),
            }
        )
    return reports
def search_reports(query: str) -> List[Dict[str, Any]]:
    """Run a semantic search and return the five closest reports.

    ``query`` is embedded with the module's embedding model and matched
    against the collection. Each result is a dict with ``id``, ``score``
    (as a float), ``title`` and ``text``; missing titles default to
    ``"(untitled)"`` and missing text to an empty string.
    """
    vector = embedding_model.encode(query).tolist()

    # ``query_points`` is the replacement for the deprecated ``search``
    # method; the matches are on the ``points`` attribute of the response.
    response = qdrant.query_points(
        collection_name=COLLECTION,
        query=vector,
        limit=5,
        with_payload=True,
    )

    results: List[Dict[str, Any]] = []
    for hit in response.points:
        # A point stored without a payload comes back with ``payload=None``.
        payload = hit.payload or {}
        results.append(
            {
                "id": hit.id,
                "score": float(hit.score),
                "title": payload.get("title", "(untitled)"),
                "text": payload.get("text", ""),
            }
        )
    return results
|