shimaa22 commited on
Commit
864d9af
·
verified ·
1 Parent(s): 2b20eaf

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +67 -58
api.py CHANGED
@@ -1,58 +1,67 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- import faiss
4
- import pickle
5
- from sentence_transformers import SentenceTransformer
6
- import numpy as np
7
- from collections import Counter
8
-
9
- # ===== LOAD =====
10
- INDEX_PATH = "faiss.index"
11
- META_PATH = "metadata.pkl"
12
-
13
- index = faiss.read_index(INDEX_PATH)
14
-
15
- with open(META_PATH, "rb") as f:
16
- meta = pickle.load(f)
17
-
18
- texts = meta["texts"]
19
- statuses = meta["statuses"]
20
-
21
- model = SentenceTransformer(
22
- "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
23
- )
24
-
25
- # ===== API =====
26
- app = FastAPI()
27
-
28
- class Query(BaseModel):
29
- text: str
30
- k: int = 5
31
-
32
- @app.post("/predict")
33
- def predict(query: Query):
34
- q_emb = model.encode([query.text]).astype("float32")
35
- distances, indices = index.search(q_emb, query.k)
36
-
37
- top_statuses = []
38
- results = []
39
-
40
- for rank, idx in enumerate(indices[0]):
41
- status = statuses[idx]
42
- top_statuses.append(status)
43
-
44
- results.append({
45
- "rank": rank + 1,
46
- "text": texts[idx],
47
- "status": status,
48
- "distance": float(distances[0][rank])
49
- })
50
-
51
- # ===== VOTING =====
52
- vote = Counter(top_statuses).most_common(1)[0]
53
-
54
- return {
55
- "prediction": vote[0],
56
- "votes": dict(Counter(top_statuses)),
57
- "top_k": results
58
- }
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import faiss
4
+ import pickle
5
+ from sentence_transformers import SentenceTransformer
6
+ import numpy as np
7
+ from collections import Counter
8
+ import gzip
9
+ import uvicorn
10
+
11
+ # ===== CONFIG =====
12
+ INDEX_PATH = "faiss.index"
13
+ META_PATH = "metadata.pkl.gz"
14
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" # خفيف ومتاح للـ Free Space
15
+
16
+ # ===== LOAD FAISS INDEX =====
17
+ index = faiss.read_index(INDEX_PATH)
18
+
19
+ with gzip.open(META_PATH, "rb") as f:
20
+ meta = pickle.load(f)
21
+
22
+ texts = meta["texts"]
23
+ statuses = meta["statuses"]
24
+
25
+ # ===== LOAD MODEL =====
26
+ model = SentenceTransformer(MODEL_NAME)
27
+
28
+ # ===== INIT API =====
29
+ app = FastAPI(title="Text Embedding Predictor")
30
+
31
+ # ===== INPUT SCHEMA =====
32
+ class Query(BaseModel):
33
+ text: str
34
+ k: int = 5 # أعلى 5 مشابهين افتراضي
35
+
36
+ # ===== PREDICTION ROUTE =====
37
+ @app.post("/predict")
38
+ def predict(query: Query):
39
+ # ===== EMBEDDING =====
40
+ q_emb = model.encode([query.text]).astype("float32")
41
+ distances, indices = index.search(q_emb, query.k)
42
+
43
+ top_statuses = []
44
+ results = []
45
+
46
+ for rank, idx in enumerate(indices[0]):
47
+ status = statuses[idx]
48
+ top_statuses.append(status)
49
+ results.append({
50
+ "rank": rank + 1,
51
+ "text": texts[idx],
52
+ "status": status,
53
+ "distance": float(distances[0][rank])
54
+ })
55
+
56
+ # ===== VOTING =====
57
+ vote = Counter(top_statuses).most_common(1)[0]
58
+
59
+ return {
60
+ "prediction": vote[0],
61
+ "votes": dict(Counter(top_statuses)),
62
+ "top_k": results
63
+ }
64
+
65
+ # ===== RUN IF MAIN =====
66
+ if __name__ == "__main__":
67
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)