avysotsky committed on
Commit
075d420
·
0 Parent(s):

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from lib.utils import ask
4
+
5
+
6
def ask_question(question, history):
    """Gradio callback: answer *question* and append the exchange to the chat log.

    Returns the updated history twice — once for the chatbot component and
    once for the session state.
    """
    chat_log = history or []
    chat_log.append((question, ask(question)))
    return chat_log, chat_log
12
+
13
+
14
# Wire the QA callback into a simple chatbot UI; flagging is disabled since
# there is nowhere for flagged examples to go.
demo = gr.Interface(
    fn=ask_question,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="Ask Lethain a question",
    allow_flagging="never",
)
demo.launch()
data/lethain.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cb0c043cb792645b49862c59203c7f5d78fabd2dd91f22cee17cf1dfe53da18
3
+ size 163150654
lib/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.32 kB). View file
 
lib/utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import openai
5
+ import tiktoken
6
+ import pandas as pd
7
+ from openai.embeddings_utils import get_embedding, cosine_similarity
8
+
9
# Tokenizer used to budget prompt context length.
encoding_name = "p50k_base"

encoding = tiktoken.get_encoding(encoding_name)

embedding_model = "text-embedding-ada-002"
# SECURITY: never commit API keys to source control — the previously
# hard-coded key is compromised and must be revoked. Read the key from the
# environment instead (None here simply defers the failure to the first call).
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Pre-computed post embeddings shipped alongside the package
# (presumably columns include `content` and `embeddings` — confirmed by usage below).
df = pd.read_pickle(Path(__file__).resolve().parent.parent / "data" / "lethain.pkl")
18
+
19
+
20
def search_reviews(df, query):
    """Rank the rows of *df* by relevance to *query*, most relevant first.

    Embeds *query* with the module-level ``embedding_model`` (previously the
    engine name was hard-coded here, duplicating that constant) and scores
    each row by cosine similarity against its precomputed embedding.

    Note: adds/overwrites a ``similarity`` column on the caller's DataFrame.
    """
    query_embedding = get_embedding(query, engine=embedding_model)
    df["similarity"] = df.embeddings.apply(
        lambda emb: cosine_similarity(emb, query_embedding)
    )
    return df.sort_values("similarity", ascending=False)
32
+
33
+
34
def construct_prompt(question: str, df: pd.DataFrame) -> str:
    """Build the completion prompt for *question*.

    Layout: persona header, then the most relevant post excerpts (greedily
    packed under a token budget), then the question itself.
    """
    MAX_SECTION_LEN = 500  # token budget for the context section
    SEPARATOR = "\n* "

    separator_len = len(encoding.encode(SEPARATOR))

    # Fetch rows most relevant to the question, best match first.
    result = search_reviews(df, question)

    chosen_sections = []
    chosen_sections_len = 0
    chosen_sections_indexes = []

    for section_index, row in result.iterrows():
        # Add contexts until we run out of space. Count the separator tokens
        # too — they are emitted with each section, so the previous code
        # (which computed separator_len but never used it) under-counted.
        tokens_num = len(encoding.encode(row.content))
        chosen_sections_len += tokens_num + separator_len
        if chosen_sections_len > MAX_SECTION_LEN:
            break

        chosen_sections.append(SEPARATOR + row.content.replace("\n", " "))
        chosen_sections_indexes.append(str(section_index))

    # Useful diagnostic information
    print(f"Selected {len(chosen_sections)} document sections:")
    print("\n".join(chosen_sections_indexes))

    # "Your" fixes a typo ("You name is") in the original prompt text.
    header = """Your name is Will Larson, you are CTO at Calm and a blogger about engineering leadership. Answer the question as truthfully as possible using the provided context, and if the answer is not contained within the text below, say "I don't know."\n\nContext:\n"""

    return header + "".join(chosen_sections) + "\n\n Q: " + question + "\n A:"
69
+
70
+
71
def ask(question):
    """Answer *question* with a zero-temperature davinci completion over the
    retrieved blog context."""
    completion = openai.Completion.create(
        prompt=construct_prompt(question, df),
        temperature=0,
        max_tokens=300,
        model="text-davinci-003",
    )
    return completion["choices"][0]["text"]
requirements.txt ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==22.1.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ anyio==3.6.2
6
+ async-timeout==4.0.2
7
+ attrs==22.2.0
8
+ blobfile==2.0.1
9
+ certifi==2022.12.7
10
+ charset-normalizer==2.1.1
11
+ click==8.1.3
12
+ contourpy==1.0.7
13
+ cycler==0.11.0
14
+ entrypoints==0.4
15
+ fastapi==0.89.1
16
+ ffmpy==0.3.0
17
+ filelock==3.9.0
18
+ fonttools==4.38.0
19
+ frozenlist==1.3.3
20
+ fsspec==2023.1.0
21
+ gradio==3.16.2
22
+ h11==0.14.0
23
+ httpcore==0.16.3
24
+ httpx==0.23.3
25
+ idna==3.4
26
+ Jinja2==3.1.2
27
+ joblib==1.2.0
28
+ jsonschema==4.17.3
29
+ kiwisolver==1.4.4
30
+ linkify-it-py==1.0.3
31
+ lxml==4.9.2
32
+ markdown-it-py==2.1.0
33
+ MarkupSafe==2.1.2
34
+ matplotlib==3.6.3
35
+ mdit-py-plugins==0.3.3
36
+ mdurl==0.1.2
37
+ multidict==6.0.4
38
+ numpy==1.24.1
39
+ openai==0.26.4
40
+ orjson==3.8.5
41
+ packaging==23.0
42
+ pandas==1.5.3
43
+ Pillow==9.4.0
44
+ plotly==5.13.0
45
+ pycryptodome==3.16.0
46
+ pycryptodomex==3.16.0
47
+ pydantic==1.10.4
48
+ pydub==0.25.1
49
+ pyparsing==3.0.9
50
+ pyrsistent==0.19.3
51
+ python-dateutil==2.8.2
52
+ python-multipart==0.0.5
53
+ pytz==2022.7.1
54
+ PyYAML==6.0
55
+ regex==2022.10.31
56
+ requests==2.28.2
57
+ rfc3986==1.5.0
58
+ scikit-learn==1.2.1
59
+ scipy==1.10.0
60
+ six==1.16.0
61
+ sklearn==0.0.post1
62
+ sniffio==1.3.0
63
+ starlette==0.22.0
64
+ tenacity==8.1.0
65
+ threadpoolctl==3.1.0
66
+ tiktoken==0.1.2
67
+ toolz==0.12.0
68
+ tqdm==4.64.1
69
+ typing_extensions==4.4.0
70
+ uc-micro-py==1.0.1
71
+ urllib3==1.26.14
72
+ uvicorn==0.20.0
73
+ websockets==10.4
74
+ yarl==1.8.2