Spaces:
No application file
No application file
adding final details
Browse files- requirements.txt +2 -1
- src/__init__.py +0 -0
- src/chat.py +22 -11
- src/index.py +1 -2
- src/main.py +29 -0
- src/prompt.py +2 -2
- src/tests/__init__.py +0 -1
- src/tests/chat_test.py +16 -13
requirements.txt
CHANGED
|
@@ -3,4 +3,5 @@ textwrap3
|
|
| 3 |
openai
|
| 4 |
python-dotenv
|
| 5 |
jsonlines
|
| 6 |
-
pytest
|
|
|
|
|
|
| 3 |
openai
|
| 4 |
python-dotenv
|
| 5 |
jsonlines
|
| 6 |
+
pytest
|
| 7 |
+
numpy
|
src/__init__.py
DELETED
|
File without changes
|
src/chat.py
CHANGED
|
@@ -3,7 +3,8 @@ import openai
|
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
from index import IndexSearchEngine
|
| 5 |
from gpt_3_manager import Gpt3Manager
|
| 6 |
-
from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
load_dotenv()
|
|
@@ -14,28 +15,38 @@ openai.api_key = OPENAI_API_KEY
|
|
| 14 |
|
| 15 |
|
| 16 |
class ChatBot:
|
| 17 |
-
def __init__(
|
|
|
|
|
|
|
| 18 |
self.index_search_engine = index_search_engine
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def ask(self, question):
|
| 21 |
-
search_result = self.index_search_engine.search(question=question)
|
| 22 |
|
| 23 |
answers = []
|
| 24 |
for result in search_result:
|
| 25 |
print("iterating over answering questions")
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
-
answer =
|
| 32 |
-
prompt=
|
| 33 |
)
|
| 34 |
answers.append(answer)
|
| 35 |
|
| 36 |
-
passage_summarization_prompt = PassageSummarizationPrompt
|
| 37 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
-
final_answer =
|
| 41 |
return final_answer
|
|
|
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
from index import IndexSearchEngine
|
| 5 |
from gpt_3_manager import Gpt3Manager
|
| 6 |
+
from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt, TextPromptLoader
|
| 7 |
+
from pathlib import Path
|
| 8 |
|
| 9 |
|
| 10 |
load_dotenv()
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class ChatBot:
|
| 18 |
+
def __init__(
|
| 19 |
+
self, index_search_engine: IndexSearchEngine, prompt_loader, gpt_manager
|
| 20 |
+
):
|
| 21 |
self.index_search_engine = index_search_engine
|
| 22 |
+
self.prompet_loader = prompt_loader
|
| 23 |
+
self.gpt_manager = gpt_manager
|
| 24 |
|
| 25 |
def ask(self, question):
|
| 26 |
+
search_result = self.index_search_engine.search(question=question, count=2)
|
| 27 |
|
| 28 |
answers = []
|
| 29 |
for result in search_result:
|
| 30 |
print("iterating over answering questions")
|
| 31 |
+
question_answering_prompt = QuestionAnsweringPrompt(
|
| 32 |
+
passage=result, question=question, prompt_loader=self.prompet_loader
|
| 33 |
+
)
|
| 34 |
+
prompt = question_answering_prompt.load(
|
| 35 |
+
Path("prompts") / "question_answering.txt"
|
| 36 |
)
|
| 37 |
|
| 38 |
+
answer = self.gpt_manager.get_completion(
|
| 39 |
+
prompt=prompt, max_tokens=80, model="text-curie-001"
|
| 40 |
)
|
| 41 |
answers.append(answer)
|
| 42 |
|
| 43 |
+
passage_summarization_prompt = PassageSummarizationPrompt(
|
| 44 |
+
"\n".join(answers), self.prompet_loader
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
prompt = passage_summarization_prompt.load(
|
| 48 |
+
Path("prompts") / "passage_summarization.txt"
|
| 49 |
)
|
| 50 |
|
| 51 |
+
final_answer = self.gpt_manager.get_completion(prompt=prompt)
|
| 52 |
return final_answer
|
src/index.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
import jsonlines
|
| 3 |
-
from gpt_3_manager import Gpt3Manager
|
| 4 |
from utils import dot_similarity
|
| 5 |
|
| 6 |
|
|
@@ -35,4 +34,4 @@ class IndexSearchEngine:
|
|
| 35 |
simmilarities, key=lambda x: x["score"], reverse=True
|
| 36 |
)
|
| 37 |
|
| 38 |
-
return sorted_similarities[:count]
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
import jsonlines
|
|
|
|
| 3 |
from utils import dot_similarity
|
| 4 |
|
| 5 |
|
|
|
|
| 34 |
simmilarities, key=lambda x: x["score"], reverse=True
|
| 35 |
)
|
| 36 |
|
| 37 |
+
return [result["index"]["content"] for result in sorted_similarities[:count]]
|
src/main.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from index import IndexSearchEngine
|
| 5 |
+
from gpt_3_manager import Gpt3Manager
|
| 6 |
+
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from chat import ChatBot
|
| 9 |
+
from index import JsonLinesIndex
|
| 10 |
+
|
| 11 |
+
from prompt import TextPromptLoader
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
path = Path("index") / "index.jsonl"
|
| 19 |
+
|
| 20 |
+
index = JsonLinesIndex()
|
| 21 |
+
loaded = index.load(path)
|
| 22 |
+
gpt_manager = Gpt3Manager(api_key=OPENAI_API_KEY)
|
| 23 |
+
|
| 24 |
+
engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
|
| 25 |
+
loader = TextPromptLoader()
|
| 26 |
+
chatbot = ChatBot(engine, prompt_loader=loader, gpt_manager=gpt_manager)
|
| 27 |
+
|
| 28 |
+
answer = chatbot.ask("What does the twitter terms of service does")
|
| 29 |
+
print(answer)
|
src/prompt.py
CHANGED
|
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
|
| 3 |
# Prompt Loaders
|
| 4 |
class PromptLoader(ABC):
|
| 5 |
@abstractmethod
|
| 6 |
-
def load_prompt():
|
| 7 |
pass
|
| 8 |
|
| 9 |
|
|
@@ -50,7 +50,7 @@ class PassageSummarizationPrompt(Prompt):
|
|
| 50 |
super().__init__(prompt_loader=prompt_loader)
|
| 51 |
self.passage = passage
|
| 52 |
|
| 53 |
-
# prompt = self.load_prompt(path).replace("<<PASSAGE>>",
|
| 54 |
|
| 55 |
def load(self, path):
|
| 56 |
prompt = self.load_prompt(path).replace("<<PASSAGE>>", self.passage)
|
|
|
|
| 3 |
# Prompt Loaders
|
| 4 |
class PromptLoader(ABC):
|
| 5 |
@abstractmethod
|
| 6 |
+
def load_prompt(self, path):
|
| 7 |
pass
|
| 8 |
|
| 9 |
|
|
|
|
| 50 |
super().__init__(prompt_loader=prompt_loader)
|
| 51 |
self.passage = passage
|
| 52 |
|
| 53 |
+
# prompt = self.load_prompt(path).replace("<<PASSAGE>>", )
|
| 54 |
|
| 55 |
def load(self, path):
|
| 56 |
prompt = self.load_prompt(path).replace("<<PASSAGE>>", self.passage)
|
src/tests/__init__.py
CHANGED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
|
|
|
|
|
|
src/tests/chat_test.py
CHANGED
|
@@ -1,28 +1,31 @@
|
|
| 1 |
import os
|
| 2 |
from pathlib import Path
|
|
|
|
| 3 |
from index import IndexSearchEngine
|
| 4 |
from gpt_3_manager import Gpt3Manager
|
|
|
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
from chat import ChatBot
|
| 7 |
from index import JsonLinesIndex
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# # assert 0 == 0
|
| 26 |
|
|
|
|
| 27 |
|
| 28 |
-
|
|
|
|
| 1 |
import os
|
| 2 |
from pathlib import Path
|
| 3 |
+
|
| 4 |
from index import IndexSearchEngine
|
| 5 |
from gpt_3_manager import Gpt3Manager
|
| 6 |
+
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from chat import ChatBot
|
| 9 |
from index import JsonLinesIndex
|
| 10 |
|
| 11 |
+
from prompt import TextPromptLoader
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
|
| 15 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 16 |
|
| 17 |
|
| 18 |
+
def test_chatbot():
|
| 19 |
+
path = Path("index") / "index.jsonl"
|
| 20 |
|
| 21 |
+
index = JsonLinesIndex()
|
| 22 |
+
loaded = index.load(path)
|
| 23 |
+
gpt_manager = Gpt3Manager(api_key=OPENAI_API_KEY)
|
|
|
|
| 24 |
|
| 25 |
+
engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
|
| 26 |
+
loader = TextPromptLoader()
|
| 27 |
+
chatbot = ChatBot(engine, prompt_loader=loader, gpt_manager=gpt_manager)
|
|
|
|
| 28 |
|
| 29 |
+
answer = chatbot.ask("What does the twitter terms of service does")
|
| 30 |
|
| 31 |
+
assert answer != None
|