Spaces:
Sleeping
Sleeping
Commit
·
fca97ef
1
Parent(s):
053ffc5
use models
Browse files- data/paris-2024-faq.json +0 -0
- server.py +98 -3
data/paris-2024-faq.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
server.py
CHANGED
|
@@ -1,14 +1,32 @@
|
|
| 1 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from fastapi import FastAPI
|
| 4 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
|
|
|
|
|
|
| 7 |
logging.basicConfig()
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
logger.setLevel(logging.INFO)
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
class InputLoad(BaseModel):
|
| 13 |
question: str
|
| 14 |
|
|
@@ -17,7 +35,31 @@ class ResponseLoad(BaseModel):
|
|
| 17 |
answer: str
|
| 18 |
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
@app.get("/health")
|
|
@@ -26,5 +68,58 @@ def health_check():
|
|
| 26 |
|
| 27 |
|
| 28 |
@app.post("/answer/")
|
| 29 |
-
async def receive(input_load: InputLoad) -> ResponseLoad:
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import json
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
from typing import Any, List, Tuple
|
| 5 |
+
import random
|
| 6 |
|
| 7 |
from fastapi import FastAPI
|
| 8 |
from pydantic import BaseModel
|
| 9 |
+
from FlagEmbedding import BGEM3FlagModel, FlagReranker
|
| 10 |
+
from starlette.requests import Request
|
| 11 |
+
import torch
|
| 12 |
|
| 13 |
|
| 14 |
+
# Fixed seed so the random canned-response selection is reproducible across runs.
random.seed(42)

# Root logging config with an INFO-level module logger for retrieval diagnostics.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
| 19 |
|
| 20 |
|
| 21 |
+
def get_data(model):
    """Load the English FAQ entries and pre-compute their ColBERT embeddings.

    Args:
        model: 1-tuple wrapping the BGE-M3 retriever (downstream code stores
            the retriever that way and indexes it with ``[0]``).

    Returns:
        Tuple of (colbert_vecs, questions, answer_bodies), aligned by index.
    """
    # Explicit encoding: the FAQ file contains non-ASCII punctuation
    # (curly quotes appear in the canned answers, so the data likely has them too).
    with open("data/paris-2024-faq.json", encoding="utf-8") as f:
        data = json.load(f)
    # Keep only the English entries.
    data = [it for it in data if it['lang'] == 'en']
    questions = [it['label'] for it in data]
    # Only ColBERT multi-vectors are needed for late-interaction scoring.
    q_embeddings = model[0].encode(questions, return_dense=False, return_sparse=False, return_colbert_vecs=True)
    return q_embeddings['colbert_vecs'], questions, [it['body'] for it in data]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
class InputLoad(BaseModel):
    # Request payload for POST /answer/: the user's free-text question.
    question: str
|
| 32 |
|
|
|
|
| 35 |
answer: str
|
| 36 |
|
| 37 |
|
| 38 |
+
class ML(BaseModel):
    # Bundle of loaded models plus precomputed FAQ data, attached to app.ml at startup.
    # Stored as a 1-tuple wrapping the BGE-M3 retriever (load_models builds it with a
    # trailing comma), so consumers index retriever[0].
    retriever: Any
    # Cross-encoder reranker used to re-score retrieved candidates.
    ranker: Any
    # (colbert_vecs, questions, answer_bodies) as returned by get_data, aligned by index.
    data: Tuple[List[Any], List[str], List[str]]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_models(app: FastAPI) -> FastAPI:
    """Instantiate retriever/reranker models, embed the FAQ data, attach both to the app.

    Returns the same ``app`` instance with ``app.ml`` populated.
    """
    # NOTE: downstream code (get_data, get_candidates) indexes retriever[0]. The
    # original relied on an accidental trailing comma to wrap the model in a
    # 1-tuple; keep that contract but make the tuple explicit and intentional.
    retriever = (BGEM3FlagModel('BAAI/bge-m3', use_fp16=True),)
    ranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
    ml = ML(
        retriever=retriever,
        ranker=ranker,
        data=get_data(retriever),
    )
    app.ml = ml
    return app
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load models once per process. load_models mutates the app in
    # place, so no rebinding is needed. Nothing to tear down on shutdown.
    load_models(app=app)
    yield
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Application instance; models and FAQ data are attached during lifespan startup.
app = FastAPI(lifespan=lifespan)
|
| 63 |
|
| 64 |
|
| 65 |
@app.get("/health")
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
@app.post("/answer/")
async def receive(input_load: InputLoad, request: Request) -> ResponseLoad:
    """Answer an FAQ question: retrieve candidates, rerank them, phrase the reply."""
    ml: ML = request.app.ml
    candidate_indices, candidate_scores = get_candidates(input_load.question, ml)
    # The reranker's own score is unused: reply wording keys off the retriever
    # score only (underscore name documents the deliberately ignored value).
    answer_candidate, _rank_score, retriever_score = rerank_candidates(
        input_load.question, candidate_indices, candidate_scores, ml
    )
    answer = get_final_answer(answer_candidate, retriever_score)
    return ResponseLoad(answer=answer)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_candidates(question, ml, topk=5):
    """Score the question against every FAQ embedding; return the top-k matches.

    Args:
        question: user question text.
        ml: ML bundle (retriever 1-tuple, embedded FAQ data).
        topk: maximum number of candidates to return.

    Returns:
        (indices, scores) as plain Python lists, best match first.
    """
    retriever = ml.retriever[0]  # retriever is stored as a 1-tuple
    question_emb = retriever.encode(
        [question], return_dense=False, return_sparse=False, return_colbert_vecs=True
    )['colbert_vecs'][0]
    scores = [retriever.colbert_score(question_emb, faq_emb) for faq_emb in ml.data[0]]
    scores_tensor = torch.stack(scores)
    # Clamp k: torch.topk raises if k exceeds the number of FAQ entries.
    k = min(topk, scores_tensor.numel())
    top_values, top_indices = torch.topk(scores_tensor, k)
    return top_indices.tolist(), top_values.tolist()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def rerank_candidates(question, indices, values, ml):
    """Re-score retriever candidates with the cross-encoder reranker.

    Args:
        question: user question text.
        indices: candidate indices into the FAQ data (from the retriever).
        values: retriever scores aligned with ``indices``.
        ml: ML bundle (reranker plus FAQ answer bodies in ml.data[2]).

    Returns:
        (best_answer_body, reranker_score, retriever_score) for the candidate
        the reranker likes best (first one on ties).
    """
    answer_bodies = [ml.data[2][idx] for idx in indices]
    pair_scores = ml.ranker.compute_score([[question, body] for body in answer_bodies])
    # Single pass: max() on enumerate pairs keeps the first maximum, matching
    # list.index semantics on ties.
    best_pos, best_score = max(enumerate(pair_scores), key=lambda p: p[1])
    return answer_bodies[best_pos], best_score, values[best_pos]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_final_answer(answer, retriever_score):
    """Phrase the final reply according to retrieval confidence.

    Below 0.65 the match is rejected and a canned apology is returned; between
    0.65 and 0.8 the answer is returned with a hedging intro; at 0.8 or above
    it is framed confidently with an intro and an outro.
    """
    # Lazy %-style args: formatting only happens if the record is emitted.
    logger.info("Retriever score: %s", retriever_score)
    if retriever_score < 0.65:
        # Nothing relevant found!
        return random.choice(NOT_FOUND_ANSWERS)
    elif retriever_score < 0.8:
        # Might be relevant, but let's be careful.
        return f"{random.choice(ROUGH_MATCH_INTROS)}\n{answer}"
    else:
        # Good match.
        return f"{random.choice(GOOD_MATCH_INTROS)}\n{answer}\n{random.choice(GOOD_MATCH_ENDS)}"
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Canned apology replies used when retrieval confidence is below threshold.
NOT_FOUND_ANSWERS = [
    "I'm sorry, but I couldn't find any information related to your question in my knowledge base.",
    "Apologies, but I don't have the information you're looking for at the moment.",
    "I’m sorry, I couldn’t locate any relevant details in my current data.",
    "Unfortunately, I wasn't able to find an answer to your query. Can I help with something else?",
    "I'm afraid I don't have the information you need right now. Please feel free to ask another question.",
    "Sorry, I couldn't find anything that matches your question in my knowledge base.",
    "I apologize, but I wasn't able to retrieve information related to your query.",
    "I'm sorry, but it looks like I don't have an answer for that. Is there anything else I can assist with?",
    "Regrettably, I couldn't find the information you requested. Can I help you with anything else?",
    "I’m sorry, but I don't have the details you're seeking in my knowledge database.",
]

# Framing strings for confident and borderline matches.
GOOD_MATCH_INTROS = ["Super!"]
# Fix user-facing typo: "Hopes this helps!" -> "Hope this helps!"
GOOD_MATCH_ENDS = ["Hope this helps!"]
ROUGH_MATCH_INTROS = ["Not sure if that answers your question!"]
|