File size: 5,538 Bytes
844cee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import openai
import time
import numpy as np
import time
import pandas as pd

# Chat models used for the two LLM calls: query augmentation and final answer.
GPT_MODEL_AUGMENT = "gpt-3.5-turbo-16k"
GPT_MODEL_ANSWER = "gpt-3.5-turbo-16k"


def cosine_similarity(a, b):
    """Return the cosine similarity between two vectors.

    Args:
        a (array-like): First vector.
        b (array-like): Second vector.

    Returns:
        float: Cosine similarity in [-1, 1]; 0.0 when either vector has
        zero norm (the original formula would divide by zero and yield NaN).
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # Zero-length vector: similarity is undefined; treat as no match.
        return 0.0
    return float(np.dot(a, b) / denom)

def get_embedding(text, model="text-embedding-ada-002", max_retries=3):
    """Embed *text* with the OpenAI embeddings API.

    Args:
        text (str): Text to embed; newlines are collapsed to spaces
            (non-string inputs are passed through unchanged, as before).
        model (str): Embedding model name.
        max_retries (int): API attempts before giving up (original code
            tried once, slept 2s, then implicitly returned None).

    Returns:
        list[float] | None: The embedding vector, or None when every
        attempt failed; callers treat a missing embedding as similarity 0.
    """
    if isinstance(text, str):
        # Newlines can degrade embedding quality; replace them with spaces.
        text = text.replace("\n", " ")
    for _ in range(max_retries):
        try:
            return openai.Embedding.create(input=[text], model=model)['data'][0]['embedding']
        except Exception:
            # Transient API error (rate limit, network): back off and retry.
            time.sleep(2)
    return None

def augment_query(query):
    """Ask the LLM to expand *query* into a JSON payload of timestamps + query.

    Args:
        query (str): Raw user question.

    Returns:
        str: The model's raw response, expected to be a JSON string with
        'timestamps' and 'query' keys.
    """

    SYS_PROMPT = """
        On [current date: 19 July] Generate a JSON response with the following structure:

        {
        "timestamps": # Relevant timestamps in which to get data to answer the query,
        "query": # Repeat the user's query,
        }
        Allowed timestamps:
        ['2018-07-18', '2018-07-19', '2018-07-08', '2018-07-09', '2018-07-10', '2018-07-11', '2018-07-12', '2018-07-13', '2018-07-14', '2018-07-15', '2018-07-16', '2018-07-17']

        Ensure the output is always in JSON format and never provide any other response.
        """
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": query},
    ]
    completion = openai.chat.completions.create(
        model=GPT_MODEL_AUGMENT,
        messages=messages,
        temperature=1,
        max_tokens=1000,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return completion.choices[0].message.content

def semantic_search(df_loc, query, timestamp, nb_programs_to_display=15):
    """Rank rows of *df_loc* for one timestamp by embedding similarity to *query*.

    Args:
        df_loc (pd.DataFrame): Corpus with 'timestamp', 'url' and 'embedding'
            columns (and, per get_final_answer, 'text' and 'source').
        query (str): Free-text query to embed and compare against.
        timestamp (str): Timestamp bucket to restrict the search to.
        nb_programs_to_display (int): Number of top matches to return.

    Returns:
        pd.DataFrame: Top matches sorted by descending 'similarity'.
    """
    timestamp = str(timestamp).strip()
    embedding = get_embedding(query, model='text-embedding-ada-002')
    # .copy() makes the slice an independent frame so the 'similarity'
    # column assignment below cannot trigger SettingWithCopyWarning or
    # silently fail on a view.
    filtered_df = df_loc[df_loc["timestamp"] == timestamp].drop(columns=["url"]).copy()

    def wrap_cos(x, y):
        # Fall back to 0 when either side is missing or malformed (e.g. the
        # query embedding is None after a failed API call).
        try:
            return cosine_similarity(x, y)
        except (TypeError, ValueError):
            return 0

    filtered_df['similarity'] = filtered_df.embedding.apply(lambda x: wrap_cos(x, embedding))
    return filtered_df.sort_values('similarity', ascending=False).head(nb_programs_to_display)

def get_relevant_documents(df, query, nb_programs_to_display=15):
    """Retrieve the top tweets for every timestamp named in the augmented query.

    Args:
        df (pd.DataFrame): Corpus with embeddings.
        query (str): JSON string produced by augment_query, with
            'timestamps' (list[str]) and 'query' (str) keys.
        nb_programs_to_display (int): Matches to keep per timestamp.

    Returns:
        list[dict]: One {'timestamp': str, 'tweets': pd.DataFrame} entry
        per requested timestamp.

    Raises:
        json.JSONDecodeError: If the LLM response is not valid JSON.
    """
    import json

    # SECURITY: the string comes from an LLM — parse it as data with
    # json.loads instead of executing it with eval().
    query = json.loads(query)
    return [
        {
            "timestamp": timestamp,
            "tweets": semantic_search(df, query["query"], timestamp,
                                      nb_programs_to_display=nb_programs_to_display),
        }
        for timestamp in query["timestamps"]
    ]

def get_final_answer(relevant_documents, query):
    """Compose the retrieved tweets into a context block and ask the LLM to answer.

    Args:
        relevant_documents (list[dict]): Output of get_relevant_documents —
            dicts with 'timestamp' and 'tweets' (a DataFrame with 'text'
            and 'source' columns).
        query (str): The (augmented) query to answer.

    Returns:
        str: The model's answer.
    """
    sections = []
    for doc in relevant_documents:
        tweets = doc["tweets"]
        formatted = (tweets["text"] + "   --- Tweeted by: @" + tweets["source"] + " \n").to_list()
        sections.append(
            "\nTimestamp: " + doc["timestamp"]
            + "\nList of tweets:\n" + str(formatted) + "\n---"
        )
    context = "".join(sections)

    USER_PROMPT = f"""
    "We have provided context information below. 
    ---------------------
    {context}
    "\n---------------------\n"
    Given this information, please answer the question: {query}
    """
    completion = openai.chat.completions.create(
        model=GPT_MODEL_ANSWER,
        messages=[{"role": "user", "content": USER_PROMPT}],
        temperature=1,
        max_tokens=1000,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return completion.choices[0].message.content

def get_answer(query, df, api_key):
    """Traditional RAG pipeline implemented without dedicated libraries.

    It first augments the query (mapping it to the relevant timestamps),
    then performs a semantic search using the augmented query, and finally
    combines the augmented query with the retrieved documents into an answer.

    Args:
        query (String): Query given by the user
        df (pd.DataFrame()): corpus with embeddings
        api_key (String): OpenAI API key

    Returns:
        String: Answer to the original query
    """
    openai.api_key = api_key
    augmented_query = augment_query(query)
    relevant_documents = get_relevant_documents(df, augmented_query,nb_programs_to_display=10)
    response = get_final_answer(relevant_documents, augmented_query,)
    return response