File size: 2,333 Bytes
864d9af b8028fc 864d9af b8028fc 864d9af b8028fc 864d9af b8028fc 864d9af b8028fc 864d9af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from fastapi import FastAPI
from pydantic import BaseModel
import faiss
import pickle
from sentence_transformers import SentenceTransformer
import numpy as np
from collections import Counter
import gzip
import uvicorn
# ===== CONFIG =====
INDEX_PATH = "faiss.index"     # serialized FAISS index produced by the indexing step
META_PATH = "metadata.pkl.gz"  # gzip-compressed pickle holding "texts" and "statuses"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # embedding model (384-dim MiniLM)
CHUNK_SIZE = 2000  # length of each chunk in characters
# ===== LOAD FAISS INDEX =====
# Module-level startup: everything below runs once at import time so the
# index, metadata and model are shared by all requests.
index = faiss.read_index(INDEX_PATH)
# NOTE(review): pickle.load is only safe on trusted input — this file is
# assumed to be produced by the companion indexing script, not user-supplied.
with gzip.open(META_PATH, "rb") as f:
    meta = pickle.load(f)
texts = meta["texts"]        # corpus documents, parallel to the index vectors
statuses = meta["statuses"]  # label for each document, same order as texts
# ===== LOAD MODEL =====
model = SentenceTransformer(MODEL_NAME)
# ===== INIT API =====
app = FastAPI(title="Text Embedding Predictor")
# ===== INPUT SCHEMA =====
class Query(BaseModel):
    """Request body for the /predict endpoint."""
    text: str   # document text to classify (chunked server-side)
    k: int = 5  # number of nearest neighbours per chunk (top 5 by default)
# ===== HELPER: split long text into chunks =====
def split_text(text, chunk_size=CHUNK_SIZE):
    """Split *text* into consecutive chunks of at most *chunk_size* characters.

    The final chunk may be shorter than *chunk_size*; an empty string
    yields an empty list. (Idiom fix: manual append loop replaced with a
    list comprehension — behavior unchanged.)
    """
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
# ===== PREDICTION ROUTE =====
@app.post("/predict")
def predict(query: Query):
    """Classify *query.text* by a k-NN majority vote over FAISS neighbours.

    The text is split into CHUNK_SIZE-character chunks; each chunk is
    embedded and searched against the index, every neighbour's status casts
    one vote, and the most common status wins.

    Returns a dict with:
      - "prediction": winning status (None when the text is empty)
      - "votes": status -> vote count across all chunks
      - "top_k": the k globally closest matches across all chunks
    """
    text_chunks = split_text(query.text)
    # Empty input produces no chunks and no votes; without this guard
    # most_common(1)[0] below would raise IndexError.
    if not text_chunks:
        return {"prediction": None, "votes": {}, "top_k": []}

    all_top_statuses = []
    all_results = []
    for chunk in text_chunks:
        # BUGFIX: the original replaced "\" with "\\" before encoding,
        # which altered the text fed to the embedding model and skewed
        # the similarity search; the raw chunk is embedded instead.
        q_emb = model.encode([chunk]).astype("float32")  # FAISS expects float32
        distances, indices = index.search(q_emb, query.k)
        for rank, idx in enumerate(indices[0]):
            # FAISS pads with -1 when the index holds fewer than k vectors.
            if idx < 0:
                continue
            status = statuses[idx]
            all_top_statuses.append(status)
            all_results.append({
                "rank": rank + 1,
                "text": texts[idx],
                "status": status,
                "distance": float(distances[0][rank]),
            })

    if not all_top_statuses:
        return {"prediction": None, "votes": {}, "top_k": []}

    # ===== VOTING =====
    vote = Counter(all_top_statuses).most_common(1)[0]
    # BUGFIX: the original returned all_results[:k] — i.e. only the first
    # chunk's hits, unsorted. Return the globally closest k instead,
    # re-ranked by ascending distance.
    top_k = sorted(all_results, key=lambda r: r["distance"])[:query.k]
    for new_rank, result in enumerate(top_k, start=1):
        result["rank"] = new_rank
    return {
        "prediction": vote[0],
        "votes": dict(Counter(all_top_statuses)),
        "top_k": top_k,
    }
# ===== RUN IF MAIN =====
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 7860, with
    # auto-reload watching the source for changes.
    uvicorn.run("app:app", reload=True, port=7860, host="0.0.0.0")
|