|
|
import gzip
import pickle
from collections import Counter

import faiss
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
# Serialized FAISS index that /predict searches.
INDEX_PATH: str = "faiss.index"
# Gzip-compressed pickle holding the metadata dict ("texts", "statuses")
# whose entries are looked up by the ids the FAISS search returns.
META_PATH: str = "metadata.pkl.gz"
# Sentence-transformers checkpoint used to embed incoming query text.
MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"
# Maximum number of characters per chunk when request text is split.
CHUNK_SIZE: int = 2000
|
|
|
|
|
|
|
|
# Load the search artifacts once at import time; every request reuses them.
index = faiss.read_index(INDEX_PATH)

# NOTE(security): pickle.load can execute arbitrary code, so META_PATH must be
# a trusted, locally produced file, never user-supplied data.
with gzip.open(META_PATH, "rb") as f:
    meta = pickle.load(f)

# Per-vector records: /predict indexes both lists with the ids returned by
# index.search, so entry i is assumed to describe vector i in `index`.
texts, statuses = meta["texts"], meta["statuses"]

# Query encoder. Presumably the same model that produced the indexed
# vectors -- confirm, otherwise distances are meaningless.
model = SentenceTransformer(MODEL_NAME)

app = FastAPI(title="Text Embedding Predictor")
|
|
|
|
|
|
|
|
class Query(BaseModel):
    """Request body for ``POST /predict``.

    pydantic validates both fields before the handler runs. An empty text or
    a non-positive ``k`` is therefore rejected with a 422. Without this check
    the handler would crash (empty vote list / invalid FAISS search).
    """

    # Text to classify; the handler splits it into CHUNK_SIZE-character chunks.
    text: str = Field(..., min_length=1)
    # Number of nearest neighbours retrieved per chunk; must be at least 1.
    k: int = Field(5, ge=1)
|
|
|
|
|
|
|
|
def split_text(text, chunk_size=CHUNK_SIZE):
    """Cut *text* into consecutive slices of at most *chunk_size* characters.

    Slices are taken left to right without overlap. The last slice holds
    whatever remains and may be shorter. An empty *text* yields ``[]``.
    """
    return [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]
|
|
|
|
|
|
|
|
@app.post("/predict")
def predict(query: Query):
    """Classify ``query.text`` by majority vote over its nearest neighbours.

    The text is cut into CHUNK_SIZE-character chunks with ``split_text``.
    All chunks are embedded with ``model`` in one batch and searched against
    the FAISS ``index``. The ``statuses`` of every retrieved neighbour, across
    all chunks, vote on the prediction.

    Returns:
        A dict with:
          - ``prediction``: the most common status among all neighbours;
          - ``votes``: mapping status -> number of neighbours carrying it;
          - ``top_k``: the first ``query.k`` result records
            (``rank``, ``text``, ``status``, ``distance``), in chunk order.

    Raises:
        HTTPException(422): ``k`` is not positive or the text is empty.
        HTTPException(503): the index is empty or returned no neighbours.
    """
    if query.k < 1:
        raise HTTPException(status_code=422, detail="k must be a positive integer")

    text_chunks = split_text(query.text)
    if not text_chunks:
        raise HTTPException(status_code=422, detail="text must not be empty")

    # Clamp k to the index size. Beyond ntotal FAISS only pads with -1 labels,
    # and an unbounded client-supplied k would make it allocate a huge
    # (chunks x k) result matrix.
    k = min(query.k, index.ntotal)
    if k < 1:
        raise HTTPException(status_code=503, detail="search index is empty")

    # NOTE(review): backslashes are doubled before embedding. Presumably this
    # mirrors how the indexed texts were preprocessed -- confirm; otherwise
    # it only distorts the query.
    escaped_chunks = [chunk.replace("\\", "\\\\") for chunk in text_chunks]

    # One batched encode + search instead of one round-trip per chunk.
    q_emb = model.encode(escaped_chunks).astype("float32")
    distances, indices = index.search(q_emb, k)

    all_top_statuses = []
    all_results = []
    for row_distances, row_indices in zip(distances, indices):
        for rank, (distance, idx) in enumerate(zip(row_distances, row_indices)):
            # FAISS marks "no result" with -1. statuses[-1] / texts[-1] would
            # silently return the last record, so skip these slots.
            if idx < 0:
                continue
            status = statuses[idx]
            all_top_statuses.append(status)
            all_results.append({
                "rank": rank + 1,
                "text": texts[idx],
                "status": status,
                "distance": float(distance),
            })

    if not all_top_statuses:
        raise HTTPException(status_code=503, detail="no neighbours found")

    votes = Counter(all_top_statuses)
    return {
        "prediction": votes.most_common(1)[0][0],
        "votes": dict(votes),
        # NOTE(review): these are the first k records in chunk order (the
        # first chunk's hits), not the k globally closest neighbours.
        "top_k": all_results[:query.k],
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point. The "module:attribute" import string "app:app"
    # presumably requires this file to be saved as app.py -- confirm if renamed.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )
|
|
|