Spaces:
Runtime error
Runtime error
development branch (#7)
Browse files
* fix relative import
* add embeddings requirement
* update openai embeddings requirements...
* format responses appropriately
* add markdown response
* Fix newline formatting
* add threshold and top_k
* update response
* fix merge conflict
- buster/chatbot.py +41 -6
buster/chatbot.py
CHANGED
|
@@ -12,13 +12,16 @@ logging.basicConfig(level=logging.INFO)
|
|
| 12 |
|
| 13 |
|
| 14 |
# search through the reviews for a specific product
|
| 15 |
-
def rank_documents(df: pd.DataFrame, query: str, top_k: int =
|
| 16 |
product_embedding = get_embedding(
|
| 17 |
query,
|
| 18 |
engine=EMBEDDING_MODEL,
|
| 19 |
)
|
| 20 |
df["similarity"] = df.embedding.apply(lambda x: cosine_similarity(x, product_embedding))
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
if top_k == -1:
|
| 23 |
# return all results
|
| 24 |
n = len(df)
|
|
@@ -28,13 +31,43 @@ def rank_documents(df: pd.DataFrame, query: str, top_k: int = 3) -> pd.DataFrame
|
|
| 28 |
|
| 29 |
|
| 30 |
def engineer_prompt(question: str, documents: list[str]) -> str:
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
|
|
|
| 35 |
# rank the documents, get the highest scoring doc and generate the prompt
|
| 36 |
-
candidates = rank_documents(df, query=question, top_k=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
documents = candidates.text.to_list()
|
|
|
|
| 38 |
prompt = engineer_prompt(question, documents)
|
| 39 |
|
| 40 |
logger.info(f"querying GPT...")
|
|
@@ -58,12 +91,14 @@ def answer_question(question: str, df) -> str:
|
|
| 58 |
GPT Response:\n{response_text}
|
| 59 |
"""
|
| 60 |
)
|
| 61 |
-
return response_text
|
|
|
|
| 62 |
except Exception as e:
|
| 63 |
import traceback
|
| 64 |
|
| 65 |
logging.error(traceback.format_exc())
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
def load_embeddings(path: str) -> pd.DataFrame:
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
# search through the reviews for a specific product
|
| 15 |
+
def rank_documents(df: pd.DataFrame, query: str, top_k: int = 1, thresh: float = None) -> pd.DataFrame:
|
| 16 |
product_embedding = get_embedding(
|
| 17 |
query,
|
| 18 |
engine=EMBEDDING_MODEL,
|
| 19 |
)
|
| 20 |
df["similarity"] = df.embedding.apply(lambda x: cosine_similarity(x, product_embedding))
|
| 21 |
|
| 22 |
+
if thresh:
|
| 23 |
+
df = df[df.similarity > thresh]
|
| 24 |
+
|
| 25 |
if top_k == -1:
|
| 26 |
# return all results
|
| 27 |
n = len(df)
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
def engineer_prompt(question: str, documents: list[str], max_chars: int = 3000) -> str:
    """Build the GPT prompt: retrieved documents followed by the user's question.

    Args:
        question: the user's question, appended verbatim at the end of the prompt.
        documents: retrieved document texts, joined with single spaces.
        max_chars: cap on the joined document text (default 3000, the value
            previously hard-coded); anything beyond it is dropped so the
            prompt stays within the model's context budget.

    Returns:
        "<documents>\nNow answer the following question:\n<question>".
    """
    documents_str = " ".join(documents)
    if len(documents_str) > max_chars:
        # Hard tail truncation. NOTE(review): this assumes earlier documents
        # are the more relevant ones — confirm against rank_documents ordering.
        logger.info("truncating documents to fit...")
        documents_str = documents_str[0:max_chars]
    return documents_str + "\nNow answer the following question:\n" + question
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def format_response(response_text, sources_url=None):
    """Format a chatbot answer as markdown/HTML for display.

    Args:
        response_text: the raw answer text, shown first.
        sources_url: optional list of source URLs; when given, each is
            appended as a markdown link on its own <br>-separated line.

    Returns:
        The formatted response string, always ending with a fenced
        bot-disclaimer footer.
    """
    response = f"{response_text}\n"

    if sources_url:
        # No interpolation needed here, so this is a plain literal (was an
        # f-string with no placeholders).
        response += "<br><br>Here are the sources I used to answer your question:\n"
        for url in sources_url:
            response += f"<br>[{url}]({url})\n"

    response += "<br><br>"
    response += """
```
I'm a bot 🤖 and not always perfect.
For more info, view the full documentation here (https://docs.mila.quebec/) or contact support@mila.quebec
```
"""
    return response
|
| 58 |
|
| 59 |
+
|
| 60 |
+
def answer_question(question: str, df, top_k: int = 1, thresh: float = None) -> str:
|
| 61 |
# rank the documents, get the highest scoring doc and generate the prompt
|
| 62 |
+
candidates = rank_documents(df, query=question, top_k=top_k, thresh=thresh)
|
| 63 |
+
|
| 64 |
+
logger.info(f"candidate responses: {candidates}")
|
| 65 |
+
|
| 66 |
+
if len(candidates) == 0:
|
| 67 |
+
return format_response("I did not find any relevant documentation related to your question.")
|
| 68 |
+
|
| 69 |
documents = candidates.text.to_list()
|
| 70 |
+
sources_url = candidates.url.to_list()
|
| 71 |
prompt = engineer_prompt(question, documents)
|
| 72 |
|
| 73 |
logger.info(f"querying GPT...")
|
|
|
|
| 91 |
GPT Response:\n{response_text}
|
| 92 |
"""
|
| 93 |
)
|
| 94 |
+
return format_response(response_text, sources_url)
|
| 95 |
+
|
| 96 |
except Exception as e:
|
| 97 |
import traceback
|
| 98 |
|
| 99 |
logging.error(traceback.format_exc())
|
| 100 |
+
response = "Oops, something went wrong. Try again later!"
|
| 101 |
+
return format_response(response)
|
| 102 |
|
| 103 |
|
| 104 |
def load_embeddings(path: str) -> pd.DataFrame:
|