kamp0010 committed on
Commit
872bcb0
·
verified ·
1 Parent(s): 2e4865a

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -464
  2. main.py +252 -0
app.py DELETED
@@ -1,464 +0,0 @@
1
- import os
2
- import builtins
3
-
4
- _real_input = builtins.input
5
- def _auto_yes(prompt=""):
6
- if any(kw in str(prompt).lower() for kw in ("custom code", "trust", "wish to run")):
7
- return "y"
8
- return _real_input(prompt)
9
- builtins.input = _auto_yes
10
-
11
- os.environ["TRUST_REMOTE_CODE"] = "1"
12
- os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
13
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
- os.environ["HF_HUB_VERBOSITY"] = "error"
15
-
16
- import streamlit as st
17
- import numpy as np
18
- import torch
19
- import re
20
- from transformers import AutoModel, AutoTokenizer
21
-
22
- # ─────────────────────────── Page config ──────────────────────────────────────
23
- st.set_page_config(
24
- page_title="pplx-embed · Semantic Search",
25
- page_icon="◈",
26
- layout="wide",
27
- initial_sidebar_state="expanded",
28
- )
29
-
30
- # ─────────────────────────── Global CSS ───────────────────────────────────────
31
- st.markdown("""
32
- <style>
33
- @import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=JetBrains+Mono:wght@300;400;500&display=swap');
34
-
35
- /* ── Base ── */
36
- html, body, [data-testid="stAppViewContainer"] {
37
- background: #0c0e14 !important;
38
- color: #e8e4d9 !important;
39
- font-family: 'JetBrains Mono', monospace !important;
40
- }
41
- [data-testid="stSidebar"] {
42
- background: #10121a !important;
43
- border-right: 1px solid #1e2235 !important;
44
- }
45
- [data-testid="stSidebar"] * { color: #e8e4d9 !important; }
46
-
47
- /* ── Hide default Streamlit chrome ── */
48
- #MainMenu, footer, header { visibility: hidden; }
49
- .block-container { padding: 2rem 2.5rem 3rem !important; max-width: 1100px !important; }
50
-
51
- /* ── Hero header ── */
52
- .hero {
53
- display: flex;
54
- align-items: flex-end;
55
- gap: 1rem;
56
- margin-bottom: 0.25rem;
57
- }
58
- .hero-icon {
59
- font-size: 2.8rem;
60
- line-height: 1;
61
- color: #f5a623;
62
- font-family: 'Syne', sans-serif;
63
- }
64
- .hero-title {
65
- font-family: 'Syne', sans-serif;
66
- font-weight: 800;
67
- font-size: 2.4rem;
68
- letter-spacing: -0.04em;
69
- color: #f0ede6;
70
- line-height: 1;
71
- }
72
- .hero-title span { color: #f5a623; }
73
- .hero-sub {
74
- font-family: 'JetBrains Mono', monospace;
75
- font-size: 0.72rem;
76
- color: #5a6080;
77
- letter-spacing: 0.12em;
78
- text-transform: uppercase;
79
- margin-bottom: 2rem;
80
- margin-top: 0.3rem;
81
- }
82
- .divider {
83
- height: 1px;
84
- background: linear-gradient(90deg, #f5a623 0%, #f5a62322 40%, transparent 100%);
85
- margin-bottom: 2rem;
86
- }
87
-
88
- /* ── Upload zone ── */
89
- [data-testid="stFileUploader"] {
90
- background: #13161f !important;
91
- border: 1px dashed #2a2e42 !important;
92
- border-radius: 8px !important;
93
- transition: border-color 0.2s;
94
- }
95
- [data-testid="stFileUploader"]:hover {
96
- border-color: #f5a623 !important;
97
- }
98
- [data-testid="stFileUploader"] * { color: #7a80a0 !important; }
99
- [data-testid="stFileUploader"] label { color: #e8e4d9 !important; }
100
-
101
- /* ── Text input ── */
102
- [data-testid="stTextInput"] input {
103
- background: #13161f !important;
104
- border: 1px solid #2a2e42 !important;
105
- border-radius: 6px !important;
106
- color: #f0ede6 !important;
107
- font-family: 'JetBrains Mono', monospace !important;
108
- font-size: 0.9rem !important;
109
- padding: 0.75rem 1rem !important;
110
- transition: border-color 0.2s, box-shadow 0.2s;
111
- }
112
- [data-testid="stTextInput"] input:focus {
113
- border-color: #f5a623 !important;
114
- box-shadow: 0 0 0 3px #f5a62318 !important;
115
- outline: none !important;
116
- }
117
- [data-testid="stTextInput"] label {
118
- color: #7a80a0 !important;
119
- font-size: 0.7rem !important;
120
- letter-spacing: 0.1em !important;
121
- text-transform: uppercase !important;
122
- font-family: 'JetBrains Mono', monospace !important;
123
- }
124
-
125
- /* ── Button ── */
126
- [data-testid="stButton"] button {
127
- background: #f5a623 !important;
128
- color: #0c0e14 !important;
129
- font-family: 'Syne', sans-serif !important;
130
- font-weight: 700 !important;
131
- font-size: 0.85rem !important;
132
- letter-spacing: 0.08em !important;
133
- text-transform: uppercase !important;
134
- border: none !important;
135
- border-radius: 6px !important;
136
- padding: 0.6rem 1.8rem !important;
137
- cursor: pointer !important;
138
- transition: background 0.15s, transform 0.1s !important;
139
- }
140
- [data-testid="stButton"] button:hover {
141
- background: #ffc048 !important;
142
- transform: translateY(-1px) !important;
143
- }
144
- [data-testid="stButton"] button:active { transform: translateY(0) !important; }
145
- [data-testid="stButton"] button:disabled {
146
- background: #1e2235 !important;
147
- color: #3a3f55 !important;
148
- cursor: not-allowed !important;
149
- transform: none !important;
150
- }
151
-
152
- /* ── Sliders ── */
153
- [data-testid="stSlider"] > div > div > div > div {
154
- background: #f5a623 !important;
155
- }
156
- [data-testid="stSlider"] label {
157
- color: #7a80a0 !important;
158
- font-size: 0.7rem !important;
159
- letter-spacing: 0.08em !important;
160
- text-transform: uppercase !important;
161
- }
162
-
163
- /* ── Expander ── */
164
- [data-testid="stExpander"] {
165
- background: #13161f !important;
166
- border: 1px solid #1e2235 !important;
167
- border-radius: 6px !important;
168
- }
169
- [data-testid="stExpander"] summary {
170
- color: #7a80a0 !important;
171
- font-size: 0.75rem !important;
172
- letter-spacing: 0.08em !important;
173
- }
174
-
175
- /* ── Alerts / info ── */
176
- [data-testid="stAlert"] {
177
- background: #13161f !important;
178
- border-radius: 6px !important;
179
- border-left: 3px solid #f5a623 !important;
180
- font-family: 'JetBrains Mono', monospace !important;
181
- font-size: 0.82rem !important;
182
- }
183
-
184
- /* ── Spinner text ── */
185
- [data-testid="stSpinner"] p { color: #7a80a0 !important; font-size: 0.8rem !important; }
186
-
187
- /* ── Sidebar labels ── */
188
- .sidebar-label {
189
- font-size: 0.65rem;
190
- letter-spacing: 0.15em;
191
- text-transform: uppercase;
192
- color: #f5a623;
193
- font-family: 'Syne', sans-serif;
194
- font-weight: 700;
195
- margin-bottom: 1rem;
196
- margin-top: 0.5rem;
197
- }
198
- .sidebar-how {
199
- font-size: 0.72rem;
200
- color: #5a6080;
201
- line-height: 1.8;
202
- border-left: 2px solid #1e2235;
203
- padding-left: 0.8rem;
204
- margin-top: 0.5rem;
205
- }
206
- .sidebar-step { color: #f5a623; font-weight: 500; }
207
-
208
- /* ── Result cards ── */
209
- @keyframes fadeSlideIn {
210
- from { opacity: 0; transform: translateY(10px); }
211
- to { opacity: 1; transform: translateY(0); }
212
- }
213
- .result-card {
214
- background: #13161f;
215
- border: 1px solid #1e2235;
216
- border-radius: 8px;
217
- padding: 1.1rem 1.3rem;
218
- margin-bottom: 0.75rem;
219
- animation: fadeSlideIn 0.3s ease both;
220
- position: relative;
221
- overflow: hidden;
222
- transition: border-color 0.2s, transform 0.15s;
223
- }
224
- .result-card:hover {
225
- border-color: #f5a62355;
226
- transform: translateX(3px);
227
- }
228
- .result-card::before {
229
- content: '';
230
- position: absolute;
231
- left: 0; top: 0; bottom: 0;
232
- width: 3px;
233
- border-radius: 8px 0 0 8px;
234
- }
235
- .card-high::before { background: #4ade80; }
236
- .card-mid::before { background: #f5a623; }
237
- .card-low::before { background: #f87171; }
238
- .card-meta {
239
- display: flex;
240
- align-items: center;
241
- gap: 0.75rem;
242
- margin-bottom: 0.6rem;
243
- }
244
- .card-rank {
245
- font-family: 'Syne', sans-serif;
246
- font-weight: 800;
247
- font-size: 0.7rem;
248
- color: #3a3f55;
249
- letter-spacing: 0.1em;
250
- }
251
- .card-score-bar {
252
- flex: 1;
253
- height: 3px;
254
- background: #1e2235;
255
- border-radius: 99px;
256
- overflow: hidden;
257
- }
258
- .card-score-fill {
259
- height: 100%;
260
- border-radius: 99px;
261
- transition: width 0.6s cubic-bezier(.16,1,.3,1);
262
- }
263
- .card-score-num {
264
- font-family: 'JetBrains Mono', monospace;
265
- font-size: 0.7rem;
266
- font-weight: 500;
267
- letter-spacing: 0.05em;
268
- }
269
- .card-text {
270
- font-family: 'JetBrains Mono', monospace;
271
- font-size: 0.82rem;
272
- line-height: 1.75;
273
- color: #c8c4b8;
274
- }
275
- .results-header {
276
- font-family: 'Syne', sans-serif;
277
- font-weight: 700;
278
- font-size: 0.7rem;
279
- letter-spacing: 0.18em;
280
- text-transform: uppercase;
281
- color: #5a6080;
282
- margin-bottom: 1rem;
283
- margin-top: 1.5rem;
284
- }
285
- .index-badge {
286
- display: inline-flex;
287
- align-items: center;
288
- gap: 0.4rem;
289
- background: #13161f;
290
- border: 1px solid #1e2235;
291
- border-radius: 4px;
292
- padding: 0.3rem 0.7rem;
293
- font-size: 0.72rem;
294
- color: #7a80a0;
295
- margin-bottom: 1rem;
296
- }
297
- .index-badge span { color: #f5a623; font-weight: 600; }
298
- </style>
299
- """, unsafe_allow_html=True)
300
-
301
- # ─────────────────────────── Model loading ────────────────────────────────────
302
- @st.cache_resource(show_spinner="◈ Loading models…")
303
- def load_models():
304
- ctx_model = AutoModel.from_pretrained("perplexity-ai/pplx-embed-context-v1-0.6B", trust_remote_code=True)
305
- query_model = AutoModel.from_pretrained("perplexity-ai/pplx-embed-v1-0.6B", trust_remote_code=True)
306
- tokenizer = AutoTokenizer.from_pretrained("perplexity-ai/pplx-embed-v1-0.6B", trust_remote_code=True)
307
- ctx_model.eval(); query_model.eval()
308
- return ctx_model, query_model, tokenizer
309
-
310
- ctx_model, query_model, tokenizer = load_models()
311
-
312
- # ─────────────────────────── Encoding helpers ─────────────────────────────────
313
- def mean_pool(token_embeddings, attention_mask):
314
- mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
315
- return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
316
-
317
- def _encode(model, texts):
318
- if hasattr(model, "encode"):
319
- result = model.encode(texts)
320
- if isinstance(result, (list, tuple)):
321
- return np.vstack([np.array(r).flatten() for r in result])
322
- return np.array(result)
323
- encoded = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
324
- with torch.no_grad():
325
- out = model(**encoded)
326
- return mean_pool(out.last_hidden_state, encoded["attention_mask"]).cpu().numpy()
327
-
328
- def embed_document_chunks(chunks):
329
- if hasattr(ctx_model, "encode"):
330
- return np.array(ctx_model.encode([chunks])[0])
331
- return _encode(ctx_model, chunks)
332
-
333
- def embed_query(query):
334
- return _encode(query_model, [query])[0].flatten()
335
-
336
- def chunk_text(text, chunk_size=3, overlap=1):
337
- sentences = re.split(r'(?<=[.!?])\s+', text.strip())
338
- sentences = [s.strip() for s in sentences if s.strip()]
339
- chunks, i = [], 0
340
- while i < len(sentences):
341
- chunks.append(" ".join(sentences[i : i + chunk_size]))
342
- i += max(1, chunk_size - overlap)
343
- return chunks
344
-
345
- def cosine_sim(a, b):
346
- na, nb = np.linalg.norm(a), np.linalg.norm(b)
347
- return float(np.dot(a, b) / (na * nb)) if na and nb else 0.0
348
-
349
- def search(query, chunks, embeddings, top_k=5):
350
- q = embed_query(query)
351
- scores = [cosine_sim(q, embeddings[i]) for i in range(len(chunks))]
352
- ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
353
- return [(chunks[idx], score) for idx, score in ranked[:top_k]]
354
-
355
- # ─────────────────────────── Sidebar ──────────────────────────────────────────
356
- with st.sidebar:
357
- st.markdown('<div class="sidebar-label">◈ Configuration</div>', unsafe_allow_html=True)
358
- chunk_size = st.slider("Sentences per chunk", 1, 8, 3)
359
- overlap = st.slider("Sentence overlap", 0, 4, 1)
360
- top_k = st.slider("Results to show", 1, 10, 5)
361
- st.markdown("---")
362
- st.markdown('<div class="sidebar-label">How it works</div>', unsafe_allow_html=True)
363
- st.markdown("""
364
- <div class="sidebar-how">
365
- <div><span class="sidebar-step">01 ·</span> File split into overlapping sentence chunks</div>
366
- <div><span class="sidebar-step">02 ·</span> Chunks embedded as one document — each chunk sees its neighbours</div>
367
- <div><span class="sidebar-step">03 ·</span> Query embedded with the standalone model</div>
368
- <div><span class="sidebar-step">04 ·</span> Cosine similarity ranks results</div>
369
- </div>
370
- """, unsafe_allow_html=True)
371
- st.markdown("---")
372
- st.markdown("""
373
- <div style="font-size:0.65rem;color:#3a3f55;line-height:1.6;">
374
- context model · pplx-embed-context-v1-0.6B<br>
375
- query model &nbsp;· pplx-embed-v1-0.6B<br>
376
- dim · 1024 · int8 · cosine
377
- </div>
378
- """, unsafe_allow_html=True)
379
-
380
- # ─────────────────────────── Main UI ──────────────────────────────────────────
381
- st.markdown("""
382
- <div class="hero">
383
- <div class="hero-icon">◈</div>
384
- <div class="hero-title">pplx<span>·</span>search</div>
385
- </div>
386
- <div class="hero-sub">contextual semantic search · perplexity embed v1</div>
387
- <div class="divider"></div>
388
- """, unsafe_allow_html=True)
389
-
390
- uploaded = st.file_uploader("Drop a document to index", type=["txt", "md"], label_visibility="visible")
391
-
392
- if uploaded:
393
- raw_text = uploaded.read().decode("utf-8", errors="replace")
394
-
395
- with st.expander(f"Preview · {uploaded.name}", expanded=False):
396
- st.code(raw_text[:4000] + ("…" if len(raw_text) > 4000 else ""), language=None)
397
-
398
- cache_key = (uploaded.name, uploaded.size, chunk_size, overlap)
399
- if st.session_state.get("cache_key") != cache_key:
400
- with st.spinner("Embedding document chunks…"):
401
- chunks = chunk_text(raw_text, chunk_size=chunk_size, overlap=overlap)
402
- embeddings = embed_document_chunks(chunks)
403
- st.session_state.update(cache_key=cache_key, chunks=chunks, embeddings=embeddings)
404
- else:
405
- chunks = st.session_state["chunks"]
406
- embeddings = st.session_state["embeddings"]
407
-
408
- chunk_count = len(chunks)
409
- st.markdown(
410
- f'<div class="index-badge">◈ indexed &nbsp;<span>{chunk_count} chunks</span>&nbsp; from &nbsp;<span>{uploaded.name}</span></div>',
411
- unsafe_allow_html=True,
412
- )
413
-
414
- col1, col2 = st.columns([4, 1])
415
- with col1:
416
- query = st.text_input("query", placeholder="Ask anything about the document…", label_visibility="collapsed")
417
- with col2:
418
- search_btn = st.button("Search ↗", disabled=not (query or "").strip(), use_container_width=True)
419
-
420
- if search_btn and query.strip():
421
- with st.spinner("Searching…"):
422
- results = search(query, chunks, embeddings, top_k=top_k)
423
-
424
- st.markdown('<div class="results-header">— Results</div>', unsafe_allow_html=True)
425
-
426
- for rank, (chunk_txt, score) in enumerate(results, 1):
427
- pct = score * 100
428
- if pct >= 60:
429
- card_cls, fill_color, score_color = "card-high", "#4ade80", "#4ade80"
430
- elif pct >= 35:
431
- card_cls, fill_color, score_color = "card-mid", "#f5a623", "#f5a623"
432
- else:
433
- card_cls, fill_color, score_color = "card-low", "#f87171", "#f87171"
434
-
435
- delay = (rank - 1) * 0.07
436
- st.markdown(f"""
437
- <div class="result-card {card_cls}" style="animation-delay:{delay}s">
438
- <div class="card-meta">
439
- <div class="card-rank">#{rank:02d}</div>
440
- <div class="card-score-bar">
441
- <div class="card-score-fill" style="width:{min(pct,100):.1f}%;background:{fill_color};"></div>
442
- </div>
443
- <div class="card-score-num" style="color:{score_color}">{pct:.1f}%</div>
444
- </div>
445
- <div class="card-text">{chunk_txt}</div>
446
- </div>
447
- """, unsafe_allow_html=True)
448
-
449
- else:
450
- st.markdown("""
451
- <div style="
452
- margin-top: 3rem;
453
- border: 1px dashed #1e2235;
454
- border-radius: 10px;
455
- padding: 3rem 2rem;
456
- text-align: center;
457
- color: #3a3f55;
458
- font-size: 0.8rem;
459
- letter-spacing: 0.08em;
460
- ">
461
- <div style="font-size:2.5rem;margin-bottom:1rem;opacity:0.3">◈</div>
462
- Upload a <code style="color:#f5a62366">.txt</code> or <code style="color:#f5a62366">.md</code> file to begin indexing
463
- </div>
464
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import builtins
3
+
4
+ # ── Auto-answer transformers custom code prompt ────────────────────────────────
5
+ _real_input = builtins.input
6
+ def _auto_yes(prompt=""):
7
+ if any(kw in str(prompt).lower() for kw in ("custom code", "trust", "wish to run")):
8
+ return "y"
9
+ return _real_input(prompt)
10
+ builtins.input = _auto_yes
11
+
12
+ os.environ["TRUST_REMOTE_CODE"] = "1"
13
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
14
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
+ os.environ["HF_HUB_VERBOSITY"] = "error"
16
+
17
+ import re
18
+ import numpy as np
19
+ import torch
20
+ from contextlib import asynccontextmanager
21
+ from typing import Annotated
22
+
23
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
24
+ from fastapi.middleware.cors import CORSMiddleware
25
+ from pydantic import BaseModel, Field
26
+ from transformers import AutoModel, AutoTokenizer
27
+
28
+
29
+ # ─────────────────────────── Models (loaded once at startup) ──────────────────
30
+ models: dict = {}
31
+
32
+ @asynccontextmanager
33
+ async def lifespan(app: FastAPI):
34
+ print("Loading embedding models…")
35
+ ctx_model = AutoModel.from_pretrained("perplexity-ai/pplx-embed-context-v1-0.6B", trust_remote_code=True)
36
+ query_model = AutoModel.from_pretrained("perplexity-ai/pplx-embed-v1-0.6B", trust_remote_code=True)
37
+ tokenizer = AutoTokenizer.from_pretrained("perplexity-ai/pplx-embed-v1-0.6B", trust_remote_code=True)
38
+ ctx_model.eval()
39
+ query_model.eval()
40
+ models["ctx"] = ctx_model
41
+ models["query"] = query_model
42
+ models["tokenizer"] = tokenizer
43
+ print("Models ready.")
44
+ yield
45
+ models.clear()
46
+
47
+
48
+ # ─────────────────────────── App ──────────────────────────────────────────────
49
+ app = FastAPI(
50
+ title="pplx-embed Semantic Search API",
51
+ description=(
52
+ "Upload a document and search it semantically using "
53
+ "perplexity-ai/pplx-embed-context-v1-0.6B + pplx-embed-v1-0.6B."
54
+ ),
55
+ version="1.0.0",
56
+ lifespan=lifespan,
57
+ )
58
+
59
+ app.add_middleware(
60
+ CORSMiddleware,
61
+ allow_origins=["*"],
62
+ allow_methods=["*"],
63
+ allow_headers=["*"],
64
+ )
65
+
66
+
67
+ # ─────────────────────────── Helpers ──────────────────────────────────────────
68
+ def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
69
+ mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
70
+ return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
71
+
72
+
73
+ def _encode(model, texts: list[str]) -> np.ndarray:
74
+ if hasattr(model, "encode"):
75
+ result = model.encode(texts)
76
+ if isinstance(result, (list, tuple)):
77
+ return np.vstack([np.array(r).flatten() for r in result])
78
+ return np.array(result)
79
+ tokenizer = models["tokenizer"]
80
+ encoded = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
81
+ with torch.no_grad():
82
+ out = model(**encoded)
83
+ return mean_pool(out.last_hidden_state, encoded["attention_mask"]).cpu().numpy()
84
+
85
+
86
+ def embed_chunks(chunks: list[str]) -> np.ndarray:
87
+ ctx = models["ctx"]
88
+ if hasattr(ctx, "encode"):
89
+ return np.array(ctx.encode([chunks])[0])
90
+ return _encode(ctx, chunks)
91
+
92
+
93
+ def embed_query_text(query: str) -> np.ndarray:
94
+ return _encode(models["query"], [query])[0].flatten()
95
+
96
+
97
+ def chunk_text(text: str, chunk_size: int = 3, overlap: int = 1) -> list[str]:
98
+ sentences = re.split(r'(?<=[.!?])\s+', text.strip())
99
+ sentences = [s.strip() for s in sentences if s.strip()]
100
+ chunks, i = [], 0
101
+ while i < len(sentences):
102
+ chunks.append(" ".join(sentences[i : i + chunk_size]))
103
+ i += max(1, chunk_size - overlap)
104
+ return chunks
105
+
106
+
107
+ def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
108
+ na, nb = np.linalg.norm(a), np.linalg.norm(b)
109
+ return float(np.dot(a, b) / (na * nb)) if na and nb else 0.0
110
+
111
+
112
+ # ─────────────────────────── In-memory document store ─────────────────────────
113
+ # Maps doc_id → { chunks: list[str], embeddings: np.ndarray }
114
+ store: dict[str, dict] = {}
115
+
116
+
117
+ # ─────────────────────────── Schemas ──────────────────────────────────────────
118
+ class IndexResponse(BaseModel):
119
+ doc_id: str
120
+ chunks_indexed: int
121
+ message: str
122
+
123
+ class SearchRequest(BaseModel):
124
+ doc_id: str = Field(..., description="ID returned by /index")
125
+ query: str = Field(..., description="Natural language question")
126
+ top_k: int = Field(5, ge=1, le=20)
127
+
128
+ class SearchResult(BaseModel):
129
+ rank: int
130
+ score: float
131
+ text: str
132
+
133
+ class SearchResponse(BaseModel):
134
+ doc_id: str
135
+ query: str
136
+ results: list[SearchResult]
137
+
138
+ class EmbedRequest(BaseModel):
139
+ texts: list[str] = Field(..., description="List of strings to embed independently")
140
+
141
+ class EmbedResponse(BaseModel):
142
+ embeddings: list[list[float]]
143
+ dimensions: int
144
+
145
+
146
+ # ─────────────────────────── Routes ───────────────────────────────────────────
147
+ @app.get("/", tags=["health"])
148
+ def root():
149
+ return {"status": "ok", "docs": "/docs"}
150
+
151
+
152
+ @app.get("/health", tags=["health"])
153
+ def health():
154
+ return {"status": "ok", "models_loaded": bool(models)}
155
+
156
+
157
+ @app.post("/index", response_model=IndexResponse, tags=["search"])
158
+ async def index_document(
159
+ file: Annotated[UploadFile, File(description=".txt or .md file to index")],
160
+ doc_id: Annotated[str, Form(description="Unique ID for this document")] = "",
161
+ chunk_size: Annotated[int, Form()] = 3,
162
+ overlap: Annotated[int, Form()] = 1,
163
+ ):
164
+ """
165
+ Upload a .txt or .md file and embed it. Returns a doc_id you use in /search.
166
+ If doc_id is empty, the filename (without extension) is used.
167
+ """
168
+ if not models:
169
+ raise HTTPException(503, "Models not loaded yet — please retry in a few seconds.")
170
+
171
+ content = await file.read()
172
+ try:
173
+ text = content.decode("utf-8")
174
+ except UnicodeDecodeError:
175
+ text = content.decode("latin-1")
176
+
177
+ resolved_id = doc_id.strip() or os.path.splitext(file.filename or "doc")[0]
178
+
179
+ chunks = chunk_text(text, chunk_size=chunk_size, overlap=overlap)
180
+ if not chunks:
181
+ raise HTTPException(400, "Document produced no text chunks. Check the file contents.")
182
+
183
+ embeddings = embed_chunks(chunks)
184
+ store[resolved_id] = {"chunks": chunks, "embeddings": embeddings}
185
+
186
+ return IndexResponse(
187
+ doc_id=resolved_id,
188
+ chunks_indexed=len(chunks),
189
+ message=f"Document '{resolved_id}' indexed successfully.",
190
+ )
191
+
192
+
193
+ @app.post("/search", response_model=SearchResponse, tags=["search"])
194
+ def search_document(req: SearchRequest):
195
+ """
196
+ Search a previously indexed document by doc_id.
197
+ """
198
+ if req.doc_id not in store:
199
+ raise HTTPException(404, f"doc_id '{req.doc_id}' not found. Call /index first.")
200
+
201
+ doc = store[req.doc_id]
202
+ chunks = doc["chunks"]
203
+ embs = doc["embeddings"]
204
+ q = embed_query_text(req.query)
205
+ scores = [cosine_sim(q, embs[i]) for i in range(len(chunks))]
206
+ ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[: req.top_k]
207
+
208
+ return SearchResponse(
209
+ doc_id=req.doc_id,
210
+ query=req.query,
211
+ results=[
212
+ SearchResult(rank=i + 1, score=round(score, 4), text=chunks[idx])
213
+ for i, (idx, score) in enumerate(ranked)
214
+ ],
215
+ )
216
+
217
+
218
+ @app.post("/embed", response_model=EmbedResponse, tags=["embeddings"])
219
+ def embed_texts(req: EmbedRequest):
220
+ """
221
+ Embed arbitrary texts with the query model. Returns raw float embeddings.
222
+ """
223
+ if not models:
224
+ raise HTTPException(503, "Models not loaded yet.")
225
+ if len(req.texts) > 64:
226
+ raise HTTPException(400, "Maximum 64 texts per request.")
227
+
228
+ embs = _encode(models["query"], req.texts)
229
+ return EmbedResponse(
230
+ embeddings=embs.tolist(),
231
+ dimensions=embs.shape[1],
232
+ )
233
+
234
+
235
+ @app.get("/documents", tags=["search"])
236
+ def list_documents():
237
+ """List all currently indexed document IDs."""
238
+ return {
239
+ "documents": [
240
+ {"doc_id": k, "chunks": len(v["chunks"])}
241
+ for k, v in store.items()
242
+ ]
243
+ }
244
+
245
+
246
+ @app.delete("/documents/{doc_id}", tags=["search"])
247
+ def delete_document(doc_id: str):
248
+ """Remove a document from the index."""
249
+ if doc_id not in store:
250
+ raise HTTPException(404, f"doc_id '{doc_id}' not found.")
251
+ del store[doc_id]
252
+ return {"deleted": doc_id}