Sulitha commited on
Commit
4fc1bf5
·
1 Parent(s): be3f7b3

replace sentence-transformers with OpenAI embeddings — fixes slow HF build

Browse files
Files changed (3) hide show
  1. Dockerfile +0 -7
  2. agent.py +20 -10
  3. requirements.txt +0 -1
Dockerfile CHANGED
@@ -1,21 +1,14 @@
1
  FROM python:3.11-slim
2
 
3
- # HF Spaces runs on port 7860
4
  ENV PORT=7860
5
  ENV PYTHONDONTWRITEBYTECODE=1
6
  ENV PYTHONUNBUFFERED=1
7
 
8
  WORKDIR /app
9
 
10
- # Install dependencies first (Docker cache layer)
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
14
- # Pre-download the embedding model at build time
15
- # so cold starts are fast (no download on first request)
16
- RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
17
-
18
- # Copy source
19
  COPY . .
20
 
21
  EXPOSE 7860
 
1
  FROM python:3.11-slim
2
 
 
3
  ENV PORT=7860
4
  ENV PYTHONDONTWRITEBYTECODE=1
5
  ENV PYTHONUNBUFFERED=1
6
 
7
  WORKDIR /app
8
 
 
9
  COPY requirements.txt .
10
  RUN pip install --no-cache-dir -r requirements.txt
11
 
 
 
 
 
 
12
  COPY . .
13
 
14
  EXPOSE 7860
agent.py CHANGED
@@ -22,13 +22,14 @@ import numpy as np
22
  import faiss
23
  import requests
24
  from openai import AsyncOpenAI
25
- from sentence_transformers import SentenceTransformer
26
 
27
  from knowledge import DOCUMENTS
28
  import session as session_store
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
 
32
  # ── Config ────────────────────────────────────────────────────────────────────
33
  OPENAI_MODEL = "gpt-4o-mini"
34
  TOP_K_CHUNKS = 3
@@ -135,23 +136,32 @@ class SulithaAgent:
135
  self.name = "Sulitha Nulaksha Bandara"
136
  self.openai = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
137
 
138
- logger.info("Loading sentence-transformer model...")
139
- self._embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
140
-
141
- logger.info(f"Embedding {len(DOCUMENTS)} knowledge chunks...")
142
  self._doc_chunks = DOCUMENTS
143
- embeddings = self._embedder.encode(DOCUMENTS, convert_to_numpy=True, show_progress_bar=False)
144
- embeddings = embeddings.astype(np.float32)
145
 
146
  dim = embeddings.shape[1]
147
  self._index = faiss.IndexFlatL2(dim)
148
  self._index.add(embeddings)
149
  logger.info(f"FAISS index built: {self._index.ntotal} vectors, dim={dim}")
150
 
 
 
 
 
 
 
 
 
 
 
 
151
  # ── RAG ───────────────────────────────────────────────────────────────────
152
 
153
- def _retrieve(self, query: str, k: int = TOP_K_CHUNKS) -> str:
154
- q_vec = self._embedder.encode([query], convert_to_numpy=True).astype(np.float32)
 
 
155
  _, idxs = self._index.search(q_vec, k)
156
  chunks = [self._doc_chunks[i] for i in idxs[0] if i < len(self._doc_chunks)]
157
  return "\n\n---\n\n".join(chunks)
@@ -198,7 +208,7 @@ class SulithaAgent:
198
  if session_store.is_over_limit(session_id):
199
  return SESSION_LIMIT_REPLY
200
 
201
- rag_context = self._retrieve(message)
202
  history = session_store.get_history(session_id)
203
 
204
  messages = (
 
22
  import faiss
23
  import requests
24
  from openai import AsyncOpenAI
 
25
 
26
  from knowledge import DOCUMENTS
27
  import session as session_store
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
+ EMBED_MODEL = "text-embedding-3-small" # fast, cheap, 1536-dim
32
+
33
  # ── Config ────────────────────────────────────────────────────────────────────
34
  OPENAI_MODEL = "gpt-4o-mini"
35
  TOP_K_CHUNKS = 3
 
136
  self.name = "Sulitha Nulaksha Bandara"
137
  self.openai = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
138
 
139
+ logger.info(f"Embedding {len(DOCUMENTS)} knowledge chunks via OpenAI...")
 
 
 
140
  self._doc_chunks = DOCUMENTS
141
+ embeddings = self._embed_texts(DOCUMENTS)
 
142
 
143
  dim = embeddings.shape[1]
144
  self._index = faiss.IndexFlatL2(dim)
145
  self._index.add(embeddings)
146
  logger.info(f"FAISS index built: {self._index.ntotal} vectors, dim={dim}")
147
 
148
+ # ── Embeddings (OpenAI) ───────────────────────────────────────────────────
149
+
150
+ def _embed_texts(self, texts: list[str]) -> np.ndarray:
151
+ """Embed a list of texts using OpenAI embeddings. Returns float32 array."""
152
+ # Use a sync client just for the startup index build
153
+ from openai import OpenAI as SyncOpenAI
154
+ client = SyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
155
+ response = client.embeddings.create(model=EMBED_MODEL, input=texts)
156
+ vecs = np.array([d.embedding for d in response.data], dtype=np.float32)
157
+ return vecs
158
+
159
  # ── RAG ───────────────────────────────────────────────────────────────────
160
 
161
+ async def _retrieve(self, query: str, k: int = TOP_K_CHUNKS) -> str:
162
+ """Embed query via OpenAI, search FAISS, return top-k chunks."""
163
+ response = await self.openai.embeddings.create(model=EMBED_MODEL, input=[query])
164
+ q_vec = np.array([response.data[0].embedding], dtype=np.float32)
165
  _, idxs = self._index.search(q_vec, k)
166
  chunks = [self._doc_chunks[i] for i in idxs[0] if i < len(self._doc_chunks)]
167
  return "\n\n---\n\n".join(chunks)
 
208
  if session_store.is_over_limit(session_id):
209
  return SESSION_LIMIT_REPLY
210
 
211
+ rag_context = await self._retrieve(message)
212
  history = session_store.get_history(session_id)
213
 
214
  messages = (
requirements.txt CHANGED
@@ -3,7 +3,6 @@ fastapi==0.115.0
3
  uvicorn[standard]==0.30.6
4
  pydantic==2.8.2
5
  openai==1.51.0
6
- sentence-transformers==3.1.1
7
  faiss-cpu==1.9.0
8
  numpy==1.26.4
9
  httpx==0.27.2
 
3
  uvicorn[standard]==0.30.6
4
  pydantic==2.8.2
5
  openai==1.51.0
 
6
  faiss-cpu==1.9.0
7
  numpy==1.26.4
8
  httpx==0.27.2