Spaces:
Paused
Paused
replace sentence-transformers with OpenAI embeddings — fixes slow HF build
Browse files- Dockerfile +0 -7
- agent.py +20 -10
- requirements.txt +0 -1
Dockerfile
CHANGED
|
@@ -1,21 +1,14 @@
|
|
| 1 |
FROM python:3.11-slim
|
| 2 |
|
| 3 |
-
# HF Spaces runs on port 7860
|
| 4 |
ENV PORT=7860
|
| 5 |
ENV PYTHONDONTWRITEBYTECODE=1
|
| 6 |
ENV PYTHONUNBUFFERED=1
|
| 7 |
|
| 8 |
WORKDIR /app
|
| 9 |
|
| 10 |
-
# Install dependencies first (Docker cache layer)
|
| 11 |
COPY requirements.txt .
|
| 12 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
|
| 14 |
-
# Pre-download the embedding model at build time
|
| 15 |
-
# so cold starts are fast (no download on first request)
|
| 16 |
-
RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
|
| 17 |
-
|
| 18 |
-
# Copy source
|
| 19 |
COPY . .
|
| 20 |
|
| 21 |
EXPOSE 7860
|
|
|
|
| 1 |
FROM python:3.11-slim
|
| 2 |
|
|
|
|
| 3 |
ENV PORT=7860
|
| 4 |
ENV PYTHONDONTWRITEBYTECODE=1
|
| 5 |
ENV PYTHONUNBUFFERED=1
|
| 6 |
|
| 7 |
WORKDIR /app
|
| 8 |
|
|
|
|
| 9 |
COPY requirements.txt .
|
| 10 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
COPY . .
|
| 13 |
|
| 14 |
EXPOSE 7860
|
agent.py
CHANGED
|
@@ -22,13 +22,14 @@ import numpy as np
|
|
| 22 |
import faiss
|
| 23 |
import requests
|
| 24 |
from openai import AsyncOpenAI
|
| 25 |
-
from sentence_transformers import SentenceTransformer
|
| 26 |
|
| 27 |
from knowledge import DOCUMENTS
|
| 28 |
import session as session_store
|
| 29 |
|
| 30 |
logger = logging.getLogger(__name__)
|
| 31 |
|
|
|
|
|
|
|
| 32 |
# ── Config ────────────────────────────────────────────────────────────────────
|
| 33 |
OPENAI_MODEL = "gpt-4o-mini"
|
| 34 |
TOP_K_CHUNKS = 3
|
|
@@ -135,23 +136,32 @@ class SulithaAgent:
|
|
| 135 |
self.name = "Sulitha Nulaksha Bandara"
|
| 136 |
self.openai = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
|
| 137 |
|
| 138 |
-
logger.info("
|
| 139 |
-
self._embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 140 |
-
|
| 141 |
-
logger.info(f"Embedding {len(DOCUMENTS)} knowledge chunks...")
|
| 142 |
self._doc_chunks = DOCUMENTS
|
| 143 |
-
embeddings = self.
|
| 144 |
-
embeddings = embeddings.astype(np.float32)
|
| 145 |
|
| 146 |
dim = embeddings.shape[1]
|
| 147 |
self._index = faiss.IndexFlatL2(dim)
|
| 148 |
self._index.add(embeddings)
|
| 149 |
logger.info(f"FAISS index built: {self._index.ntotal} vectors, dim={dim}")
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
# ── RAG ───────────────────────────────────────────────────────────────────
|
| 152 |
|
| 153 |
-
def _retrieve(self, query: str, k: int = TOP_K_CHUNKS) -> str:
|
| 154 |
-
|
|
|
|
|
|
|
| 155 |
_, idxs = self._index.search(q_vec, k)
|
| 156 |
chunks = [self._doc_chunks[i] for i in idxs[0] if i < len(self._doc_chunks)]
|
| 157 |
return "\n\n---\n\n".join(chunks)
|
|
@@ -198,7 +208,7 @@ class SulithaAgent:
|
|
| 198 |
if session_store.is_over_limit(session_id):
|
| 199 |
return SESSION_LIMIT_REPLY
|
| 200 |
|
| 201 |
-
rag_context = self._retrieve(message)
|
| 202 |
history = session_store.get_history(session_id)
|
| 203 |
|
| 204 |
messages = (
|
|
|
|
| 22 |
import faiss
|
| 23 |
import requests
|
| 24 |
from openai import AsyncOpenAI
|
|
|
|
| 25 |
|
| 26 |
from knowledge import DOCUMENTS
|
| 27 |
import session as session_store
|
| 28 |
|
| 29 |
logger = logging.getLogger(__name__)
|
| 30 |
|
| 31 |
+
EMBED_MODEL = "text-embedding-3-small" # fast, cheap, 1536-dim
|
| 32 |
+
|
| 33 |
# ── Config ────────────────────────────────────────────────────────────────────
|
| 34 |
OPENAI_MODEL = "gpt-4o-mini"
|
| 35 |
TOP_K_CHUNKS = 3
|
|
|
|
| 136 |
self.name = "Sulitha Nulaksha Bandara"
|
| 137 |
self.openai = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
|
| 138 |
|
| 139 |
+
logger.info(f"Embedding {len(DOCUMENTS)} knowledge chunks via OpenAI...")
|
|
|
|
|
|
|
|
|
|
| 140 |
self._doc_chunks = DOCUMENTS
|
| 141 |
+
embeddings = self._embed_texts(DOCUMENTS)
|
|
|
|
| 142 |
|
| 143 |
dim = embeddings.shape[1]
|
| 144 |
self._index = faiss.IndexFlatL2(dim)
|
| 145 |
self._index.add(embeddings)
|
| 146 |
logger.info(f"FAISS index built: {self._index.ntotal} vectors, dim={dim}")
|
| 147 |
|
| 148 |
+
# ── Embeddings (OpenAI) ───────────────────────────────────────────────────
|
| 149 |
+
|
| 150 |
+
def _embed_texts(self, texts: list[str]) -> np.ndarray:
|
| 151 |
+
"""Embed a list of texts using OpenAI embeddings. Returns float32 array."""
|
| 152 |
+
# Use a sync client just for the startup index build
|
| 153 |
+
from openai import OpenAI as SyncOpenAI
|
| 154 |
+
client = SyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
|
| 155 |
+
response = client.embeddings.create(model=EMBED_MODEL, input=texts)
|
| 156 |
+
vecs = np.array([d.embedding for d in response.data], dtype=np.float32)
|
| 157 |
+
return vecs
|
| 158 |
+
|
| 159 |
# ── RAG ───────────────────────────────────────────────────────────────────
|
| 160 |
|
| 161 |
+
async def _retrieve(self, query: str, k: int = TOP_K_CHUNKS) -> str:
    """Embed query via OpenAI, search FAISS, return top-k chunks.

    Args:
        query: The user text to find relevant knowledge chunks for.
        k: Maximum number of chunks to return.

    Returns:
        The matching chunks joined with a ``---`` separator line, or an
        empty string when no valid match is found.
    """
    response = await self.openai.embeddings.create(model=EMBED_MODEL, input=[query])
    q_vec = np.array([response.data[0].embedding], dtype=np.float32)
    _, idxs = self._index.search(q_vec, k)

    # FAISS fills unused result slots with -1 when the index holds fewer
    # than k vectors. An upper-bound check alone lets -1 through, and
    # Python would then silently return the LAST chunk, so require a
    # non-negative index as well.
    n_chunks = len(self._doc_chunks)
    chunks = [self._doc_chunks[i] for i in idxs[0] if 0 <= i < n_chunks]
    return "\n\n---\n\n".join(chunks)
|
|
|
|
| 208 |
if session_store.is_over_limit(session_id):
|
| 209 |
return SESSION_LIMIT_REPLY
|
| 210 |
|
| 211 |
+
rag_context = await self._retrieve(message)
|
| 212 |
history = session_store.get_history(session_id)
|
| 213 |
|
| 214 |
messages = (
|
requirements.txt
CHANGED
|
@@ -3,7 +3,6 @@ fastapi==0.115.0
|
|
| 3 |
uvicorn[standard]==0.30.6
|
| 4 |
pydantic==2.8.2
|
| 5 |
openai==1.51.0
|
| 6 |
-
sentence-transformers==3.1.1
|
| 7 |
faiss-cpu==1.9.0
|
| 8 |
numpy==1.26.4
|
| 9 |
httpx==0.27.2
|
|
|
|
| 3 |
uvicorn[standard]==0.30.6
|
| 4 |
pydantic==2.8.2
|
| 5 |
openai==1.51.0
|
|
|
|
| 6 |
faiss-cpu==1.9.0
|
| 7 |
numpy==1.26.4
|
| 8 |
httpx==0.27.2
|