1MR committed on
Commit
353ede8
·
verified ·
1 Parent(s): 3378eaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -43
app.py CHANGED
@@ -1,23 +1,22 @@
1
  """
2
  KG Embedding Server — FastAPI on HuggingFace Spaces
3
- يشتغل كـ REST API، يحمل الموديل مرة واحدة بس
4
  """
5
 
6
- from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel
8
- from typing import Optional
9
- import math
10
- import re
11
- import contextlib
12
  from io import StringIO
13
 
14
  app = FastAPI(title="KG Embedding Server")
15
 
16
  # ══════════════════════════════════════════════════════
17
- # GLOBALS β€” Ψ§Ω„Ω…ΩˆΨ―ΩŠΩ„ بيΨͺΨ­Ω…Ω„ Ω…Ψ±Ψ© واحدة ΨΉΩ†Ψ― startup
18
  # ══════════════════════════════════════════════════════
19
  _model = None
20
  _use_st = False
 
21
 
22
  @app.on_event("startup")
23
  def load_model():
@@ -30,24 +29,35 @@ def load_model():
30
  print("[Server] sentence-transformers loaded ✓")
31
  except Exception as e:
32
  print(f"[Server] ST unavailable: {e}")
33
- _use_st = False
34
 
35
  # ══════════════════════════════════════════════════════
36
- # REQUEST / RESPONSE MODELS
37
  # ══════════════════════════════════════════════════════
38
- class EmbedRequest(BaseModel):
39
- texts: list[str]
 
 
40
 
41
- class EmbedResponse(BaseModel):
42
- embeddings: list[list[float]]
43
- model: str = "all-MiniLM-L6-v2"
44
 
45
- class SimilarityRequest(BaseModel):
 
 
 
 
 
46
  query: str
47
- candidates: list[str] # النصوص اللي هنقيس عليها
 
 
 
 
48
 
49
- class SimilarityResponse(BaseModel):
50
- scores: list[float] # cosine similarity لكل candidate
 
51
 
52
  class HealthResponse(BaseModel):
53
  status: str
@@ -61,37 +71,66 @@ class HealthResponse(BaseModel):
61
  def health():
62
  return {"status": "ok", "model_loaded": _use_st}
63
 
64
- @app.post("/embed", response_model=EmbedResponse)
65
- def embed(req: EmbedRequest):
 
 
 
 
 
 
 
66
  if not _use_st or _model is None:
67
- raise HTTPException(503, "Model not loaded")
68
- if not req.texts:
69
- return EmbedResponse(embeddings=[])
 
 
 
 
70
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
71
- vecs = _model.encode(req.texts, show_progress_bar=False)
72
- return EmbedResponse(embeddings=[v.tolist() for v in vecs])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- @app.post("/similarity", response_model=SimilarityResponse)
75
- def similarity(req: SimilarityRequest):
76
- """Query + candidates → cosine scores (الأسرع لو عايز ترتب نتايج)"""
77
- if not _use_st or _model is None:
78
- raise HTTPException(503, "Model not loaded")
79
- texts = [req.query] + req.candidates
80
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
81
- vecs = _model.encode(texts, show_progress_bar=False)
82
- qvec = vecs[0]
83
- scores = [_cosine(qvec, v) for v in vecs[1:]]
84
- return SimilarityResponse(scores=scores)
 
 
 
 
 
 
 
 
 
 
85
 
86
- def _cosine(a, b) -> float:
87
- dot = sum(x * y for x, y in zip(a, b))
88
- na = math.sqrt(sum(x * x for x in a)) or 1e-9
89
- nb = math.sqrt(sum(x * x for x in b)) or 1e-9
90
- return dot / (na * nb)
91
 
92
- # ══════════════════════════════════════════════════════
93
- # PING β€” Ω„Ω…Ω†ΨΉ Ψ§Ω„Ω€ Space Ω…Ω† Ψ§Ω„Ω†ΩˆΩ… (Ψ§ΨΉΩ…Ω„ cron job يبعΨͺΩ‡ ΩƒΩ„ 4 Ψ―Ω‚Ψ§ΩŠΩ‚)
94
- # ══════════════════════════════════════════════════════
95
  @app.get("/ping")
96
  def ping():
97
  return {"pong": True}
 
1
  """
2
  KG Embedding Server — FastAPI on HuggingFace Spaces
3
+ Session-based: client syncs once → server holds FAISS index in RAM
4
  """
5
 
6
+ from fastapi import FastAPI
7
  from pydantic import BaseModel
8
+ from typing import Dict, Any, Optional
9
+ import math, contextlib, uuid
 
 
10
  from io import StringIO
11
 
12
  app = FastAPI(title="KG Embedding Server")
13
 
14
  # ══════════════════════════════════════════════════════
15
+ # GLOBALS
16
  # ══════════════════════════════════════════════════════
17
  _model = None
18
  _use_st = False
19
+ _sessions: Dict[str, Dict[str, Any]] = {} # session_id → {nodes, index}
20
 
21
  @app.on_event("startup")
22
  def load_model():
 
29
  print("[Server] sentence-transformers loaded ✓")
30
  except Exception as e:
31
  print(f"[Server] ST unavailable: {e}")
 
32
 
33
  # ══════════════════════════════════════════════════════
34
+ # MODELS
35
  # ══════════════════════════════════════════════════════
36
+ class NodePayload(BaseModel):
37
+ id: int
38
+ title: str
39
+ content: str
40
 
41
+ class SyncRequest(BaseModel):
42
+ session_id: str
43
+ nodes: list[NodePayload]
44
 
45
+ class SyncResponse(BaseModel):
46
+ status: str
47
+ count: int
48
+
49
+ class SearchRequest(BaseModel):
50
+ session_id: str
51
  query: str
52
+ top_k: int = 8
53
+
54
+ class SearchResult(BaseModel):
55
+ node_id: int
56
+ score: float
57
 
58
+ class SearchResponse(BaseModel):
59
+ results: list[SearchResult]
60
+ session_missing: bool = False
61
 
62
  class HealthResponse(BaseModel):
63
  status: str
 
71
  def health():
72
  return {"status": "ok", "model_loaded": _use_st}
73
 
74
+ @app.post("/sync", response_model=SyncResponse)
75
+ def sync(req: SyncRequest):
76
+ """Client uploads nodes once → server builds FAISS index in RAM."""
77
+ try:
78
+ import faiss
79
+ import numpy as np
80
+ except ImportError:
81
+ raise Exception("faiss not installed on server")
82
+
83
  if not _use_st or _model is None:
84
+ return SyncResponse(status="no_model", count=0)
85
+
86
+ texts = [f"{n.title} {n.content}" for n in req.nodes]
87
+ if not texts:
88
+ _sessions[req.session_id] = {"nodes": [], "index": None}
89
+ return SyncResponse(status="ok", count=0)
90
+
91
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
92
+ vecs = _model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
93
+
94
+ vecs = np.array(vecs, dtype="float32")
95
+ dim = vecs.shape[1]
96
+
97
+ index = faiss.IndexFlatIP(dim) # cosine sim (vectors already normalized)
98
+ index.add(vecs)
99
+
100
+ _sessions[req.session_id] = {
101
+ "nodes": [n.dict() for n in req.nodes],
102
+ "index": index,
103
+ }
104
+ print(f"[Server] Session {req.session_id[:8]}… → {len(texts)} nodes indexed ✓")
105
+ return SyncResponse(status="ok", count=len(texts))
106
+
107
+
108
+ @app.post("/search", response_model=SearchResponse)
109
+ def search(req: SearchRequest):
110
+ """Embed query only, FAISS search against cached index."""
111
+ import numpy as np
112
+
113
+ session = _sessions.get(req.session_id)
114
+ if not session or session["index"] is None:
115
+ return SearchResponse(results=[], session_missing=True)
116
 
 
 
 
 
 
 
117
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
118
+ qvec = _model.encode([req.query], normalize_embeddings=True)
119
+
120
+ qvec = np.array(qvec, dtype="float32")
121
+ k = min(req.top_k, len(session["nodes"]))
122
+ scores, indices = session["index"].search(qvec, k)
123
+
124
+ results = []
125
+ for score, idx in zip(scores[0], indices[0]):
126
+ if idx >= 0:
127
+ results.append(SearchResult(
128
+ node_id=session["nodes"][idx]["id"],
129
+ score=float(score)
130
+ ))
131
+ return SearchResponse(results=results)
132
 
 
 
 
 
 
133
 
 
 
 
134
  @app.get("/ping")
135
  def ping():
136
  return {"pong": True}