Spaces:
No application file
No application file
Refactoring the code according to the SOLID principles
Browse files- chat.py +0 -66
- index/build_index.py +1 -1
- src/chat.py +41 -0
- src/gpt_3_manager.py +32 -0
- src/index.py +40 -0
- src/prompt.py +35 -0
- utils.py → src/utils.py +0 -29
chat.py
DELETED
|
@@ -1,66 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import openai
|
| 3 |
-
from dotenv import load_dotenv
|
| 4 |
-
import jsonlines
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from utils import (
|
| 7 |
-
gpt3_embeddings,
|
| 8 |
-
gpt3_completion,
|
| 9 |
-
dot_similarity,
|
| 10 |
-
load_prompt,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
load_dotenv()
|
| 14 |
-
|
| 15 |
-
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 16 |
-
|
| 17 |
-
openai.api_key = OPENAI_API_KEY
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def search_index(question, indexes, count=4):
    """Return the `count` index entries most similar to `question`.

    Each entry of `indexes` must carry an "embedding" key; the result is a
    list of {"index": entry, "score": similarity} dicts, best match first.
    """
    query_vector = gpt3_embeddings(question)

    scored = [
        {"index": entry, "score": dot_similarity(query_vector, entry["embedding"])}
        for entry in indexes
    ]

    # Highest similarity first; sort is stable, matching the original order
    # for equal scores.
    scored.sort(key=lambda item: item["score"], reverse=True)

    return scored[:count]
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
if __name__ == "__main__":
    # Load the pre-built embedding index: one JSON object per line.
    with jsonlines.open(Path("./index") / "index.jsonl") as passages:
        indexes = list(passages)

    # Simple REPL: retrieve the 2 best passages, answer from each, then
    # summarize the per-passage answers into one final reply.
    while True:
        question = input("User >")

        search_results = search_index(question=question, indexes=indexes, count=2)

        answers = []
        for result in search_results:
            print("iterating over answering questions")

            # FIX: the prompt paths used backslashes ("prompts\question_...")
            # which only resolve on Windows; forward slashes work everywhere.
            prompt = (
                load_prompt("prompts/question_answering.txt")
                .replace("<<PASSAGE>>", result["index"]["content"])
                .replace("<<QUESTION>>", question)
            )

            answer = gpt3_completion(
                prompt=prompt, max_tokens=80, model="text-curie-001"
            )
            answers.append(answer)

        prompt = load_prompt("prompts/passage_summarization.txt").replace(
            "<<PASSAGE>>", "\n".join(answers)
        )

        final_answer = gpt3_completion(prompt=prompt)

        print(f"Bot: {final_answer}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
index/build_index.py
CHANGED
|
@@ -7,7 +7,7 @@ import openai
|
|
| 7 |
import textwrap
|
| 8 |
import jsonlines
|
| 9 |
|
| 10 |
-
from utils import gpt3_embeddings
|
| 11 |
|
| 12 |
load_dotenv()
|
| 13 |
|
|
|
|
| 7 |
import textwrap
|
| 8 |
import jsonlines
|
| 9 |
|
| 10 |
+
from src.utils import gpt3_embeddings
|
| 11 |
|
| 12 |
load_dotenv()
|
| 13 |
|
src/chat.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import openai
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from index import IndexSearchEngine
|
| 5 |
+
from gpt_3_manager import Gpt3Manager
|
| 6 |
+
from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 12 |
+
|
| 13 |
+
openai.api_key = OPENAI_API_KEY
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ChatBot:
    """Answers a question by searching an embedding index and summarizing
    per-passage GPT-3 answers into one reply."""

    def __init__(self, index_search_engine: IndexSearchEngine):
        self.index_search_engine = index_search_engine

    def ask(self, question):
        """Return a single summarized answer to `question`.

        Retrieves the best-matching passages, asks GPT-3 to answer from each,
        then summarizes the per-passage answers. Returns the summary text
        (or None if the completion call failed).
        """
        search_result = self.index_search_engine.search(question=question)

        answers = []
        for result in search_result:
            print("iterating over answering questions")

            # FIX: the Prompt classes carry the passage/question as instance
            # state, so an instance must be constructed before `load` — the
            # original `QuestionAnsweringPrompt.load(path)` passed the path
            # as `self` and crashed. Paths now use forward slashes so they
            # resolve on non-Windows hosts too.
            question_answering_prompt = QuestionAnsweringPrompt(
                result, question
            ).load("prompts/question_answering.txt")

            answer = Gpt3Manager.get_completion(
                prompt=question_answering_prompt, max_tokens=80, model="text-curie-001"
            )
            answers.append(answer)

        # FIX: the summarization prompt needs the collected answers; the
        # original never handed them over.
        passage_summarization_prompt = PassageSummarizationPrompt(answers).load(
            "prompts/passage_summarization.txt"
        )

        final_answer = Gpt3Manager.get_completion(prompt=passage_summarization_prompt)
        return final_answer
|
src/gpt_3_manager.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Gpt3Manager:
    """Thin wrapper around the OpenAI completion and embedding endpoints.

    The worker methods are static so callers can use them straight off the
    class (`Gpt3Manager.get_completion(...)`, as src/chat.py does);
    instantiating with an API key only configures the global `openai` module.
    """

    def __init__(self, api_key):
        # Sets the module-global key; affects every subsequent API call.
        openai.api_key = api_key

    # FIX: these methods had no `self` parameter and no decorator, so calling
    # them on an *instance* would have passed the instance as `prompt`/`text`.
    # `@staticmethod` keeps the existing class-level call sites working and
    # makes instance-level calls safe as well.
    @staticmethod
    def get_completion(prompt, max_tokens=128, model="text-davinci-003"):
        """Return the completion text for `prompt`, or None on API failure."""
        response = None
        try:
            response = openai.Completion.create(
                model=model,
                prompt=prompt,
                max_tokens=max_tokens,
            )["choices"][0]["text"]

        except Exception as err:
            # Best-effort: report the failure and fall through with None.
            print(f"Sorry, There was a problem \n\n {err}")

        return response

    @staticmethod
    def get_embedding(text, model="text-similarity-ada-001"):
        """Return the embedding vector for `text`, or None on API failure."""
        # Collapse newlines to spaces before embedding.
        text = text.replace("\n", " ")
        embedding = None
        try:
            embedding = openai.Embedding.create(input=[text], model=model)["data"][0][
                "embedding"
            ]
        except Exception as err:
            print(f"Sorry, There was a problem {err}")

        return embedding
|
src/index.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import jsonlines
|
| 3 |
+
from gpt_3_manager import Gpt3Manager
|
| 4 |
+
from src.utils import dot_similarity
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Index(ABC):
    """Interface for loading a passage index from storage."""

    @abstractmethod
    def load(self, path):
        """Load and return the index entries stored at `path`."""
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class JsonLinesIndex(Index):
    """Index persisted as a JSON-Lines file: one JSON object per line."""

    def __init__(self):
        pass

    def load(self, path):
        """Read every record from the .jsonl file at `path` into a list."""
        with jsonlines.open(path) as reader:
            return list(reader)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class IndexSearchEngine:
    """Ranks index entries by embedding similarity to a question."""

    def __init__(self, index):
        # FIX: the original `index = index` bound a throwaway local instead of
        # an attribute, so the engine forgot its index immediately.
        self.index = index

    def search(self, question, indexes=None, count=4):
        """Return the `count` entries most similar to `question`.

        `indexes` now defaults to the index given at construction time —
        the only in-repo caller (ChatBot) invokes `search(question=...)`
        with no index, which raised TypeError before. Each entry must carry
        an "embedding" key; results are {"index": entry, "score": score}
        dicts, best match first.
        """
        if indexes is None:
            indexes = self.index

        question_embedding = Gpt3Manager.get_embedding(question)

        simmilarities = []
        for index in indexes:
            embedding = index["embedding"]
            score = dot_similarity(question_embedding, embedding)
            simmilarities.append({"index": index, "score": score})

        sorted_similarities = sorted(
            simmilarities, key=lambda x: x["score"], reverse=True
        )

        return sorted_similarities[:count]
|
src/prompt.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Prompt(ABC):
    """Base class for file-backed prompt templates."""

    # FIX: the original `def load_prompt(path):` had no `self`, so the
    # subclass calls `self.load_prompt(path)` passed the instance as the
    # path and raised TypeError. `@staticmethod` fixes the instance call
    # while keeping `Prompt.load_prompt(path)` working too.
    @staticmethod
    def load_prompt(path):
        """Return the raw template text stored at `path`."""
        with open(path) as f:
            lines = f.readlines()
        return "".join(lines)

    @abstractmethod
    def load(self, path):
        """Return the template at `path` with its placeholders filled in."""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class QuestionAnsweringPrompt(Prompt):
    """Prompt asking GPT-3 to answer `question` from one retrieved passage."""

    def __init__(self, result, question):
        # FIX: the original `result = result` / `question = question` bound
        # locals, never the attributes, so `load` crashed on `self.result`.
        self.result = result
        self.question = question

    def load(self, path):
        """Return the template at `path` with passage and question filled in.

        `result` is one search hit; the passage text is read from
        result["index"]["content"].
        """
        prompt = (
            self.load_prompt(path)
            .replace("<<PASSAGE>>", self.result["index"]["content"])
            .replace("<<QUESTION>>", self.question)
        )
        return prompt
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PassageSummarizationPrompt(Prompt):
    """Prompt asking GPT-3 to summarize a list of per-passage answers."""

    def __init__(self, answers):
        self.answers = answers

    def load(self, path):
        """Return the template at `path` with the joined answers filled in."""
        joined_answers = "\n".join(self.answers)
        return self.load_prompt(path).replace("<<PASSAGE>>", joined_answers)
|
utils.py → src/utils.py
RENAMED
|
@@ -1,35 +1,6 @@
|
|
| 1 |
-
import openai
|
| 2 |
import numpy as np
|
| 3 |
|
| 4 |
|
| 5 |
-
def gpt3_embeddings(text, model="text-similarity-ada-001"):
    """Return the embedding vector for `text`, or None if the API call fails."""
    # Collapse newlines to spaces before sending the text to the API.
    text = text.replace("\n", " ")
    embedding = None
    try:
        embedding = openai.Embedding.create(input=[text], model=model)["data"][0][
            "embedding"
        ]
    except Exception as err:
        # Best-effort: report the failure and return None instead of raising.
        print(f"Sorry, There was a problem {err}")

    return embedding
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def gpt3_completion(prompt, max_tokens=128, model="text-davinci-003"):
    """Return the completion text for `prompt`, or None if the API call fails."""
    response = None
    try:
        response = openai.Completion.create(
            model=model,
            prompt=prompt,
            max_tokens=max_tokens,
        )["choices"][0]["text"]

    except Exception as err:
        # Best-effort: report the failure and fall through with None.
        print(f"Sorry, There was a problem \n\n {err}")

    return response
| 31 |
-
|
| 32 |
-
|
| 33 |
def load_prompt(path):
|
| 34 |
with open(path) as f:
|
| 35 |
lines = f.readlines()
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
def load_prompt(path):
|
| 5 |
with open(path) as f:
|
| 6 |
lines = f.readlines()
|