File size: 9,075 Bytes
27ec9ac
 
 
 
 
 
 
 
 
2c194ba
 
 
 
27ec9ac
 
 
 
 
 
 
 
 
2c194ba
 
27ec9ac
 
 
 
 
 
 
 
 
 
 
 
2c194ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27ec9ac
 
 
 
2c194ba
27ec9ac
 
 
 
 
 
 
 
2c194ba
 
27ec9ac
2c194ba
 
27ec9ac
 
2c194ba
 
 
 
 
27ec9ac
 
2c194ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27ec9ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c194ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from langchain.document_loaders import PyPDFLoader, PyPDFDirectoryLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.llms import AzureOpenAI, OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA, ConversationalRetrievalChain, RetrievalQAWithSourcesChain
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.output_parsers import CommaSeparatedListOutputParser
# from langchain.chains.summarize import load_summarize_chain 

from langchain.chat_models import AzureChatOpenAI


import os
import openai
# Record the working directory at import time; PDFEmbeddings falls back to
# os.environ['CWD'] + '/archive' when no path is given.
os.environ['CWD'] = os.getcwd()

# for testing
# import src.constants as constants
import constants 
# Export the Azure OpenAI settings (the "FR" key/endpoint from constants) as
# environment variables. Presumably these are picked up by the langchain /
# openai clients instantiated later in this module -- TODO confirm.
os.environ['OPENAI_API_KEY'] = constants.AZURE_OPENAI_KEY_FR
os.environ['OPENAI_API_BASE'] = constants.AZURE_OPENAI_ENDPOINT_FR
os.environ['OPENAI_API_VERSION'] = "2023-05-15"
os.environ['OPENAI_API_TYPE'] = "azure"
# openai.api_type = "azure"
# openai.api_base = constants.AZURE_OPENAI_ENDPOINT_FR
# openai.api_version = "2023-05-15"
# NOTE(review): the module-level openai key is set from constants.OPEN_AI_KEY,
# while the OPENAI_API_KEY environment variable above uses
# constants.AZURE_OPENAI_KEY_FR -- confirm that using two different keys is
# intended.
openai.api_key = constants.OPEN_AI_KEY


import os
from typing import Optional
import hardcoded_data 

class TLDR():
    """Produce per-article TL;DR summaries for categorized article abstracts.

    Typical use: call ``load_text`` to stage the articles, then ``summarize``
    to run one LLM call per category and get the parsed summaries back.
    """

    def __init__(self):
        """Read the prompt templates from ``constants`` and build the chat model."""
        self.prompt_template_general = constants.prompt_template_general
        self.prompt_template_scientific = constants.prompt_template_scientific
        # Start with no staged documents so that summarize() called before
        # load_text() returns an empty list instead of raising AttributeError.
        self.documents = []

        self.llm = AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR, temperature=0)

    def load_text(self, category_list):
        """Stage articles for summarization, replacing anything staged before.

        Args:
            category_list: iterable of dicts, each with a ``'category'`` name
                and an ``'articles'`` list whose items carry ``"title"`` and
                ``"abstract"`` keys.

        The result is stored on ``self.documents`` as a list of
        ``{'category': name, 'articles': [formatted text, ...]}`` dicts.
        """
        self.documents = [
            {
                'category': articles_list['category'],
                'articles': [
                    f"Title: {article['title']}\n\nAbstract: {article['abstract']}"
                    for article in articles_list['articles']
                ],
            }
            for articles_list in category_list
        ]

    def parse_function(self, input_string):
        """Parse ``Title: ...`` / ``Tldr: ...`` text into a list of dicts.

        Args:
            input_string: text made of entries shaped like
                ``"Title: <title>\\nTldr: <summary>"``.

        Returns:
            A list of ``{"title": ..., "tldr": ...}`` dicts, one per entry, in
            input order. Text before the first ``"Title: "`` is ignored. An
            entry without a ``"\\nTldr: "`` marker keeps its whole text as the
            title and gets an empty ``"tldr"`` rather than raising, so one
            malformed entry does not discard the rest of the output.
        """
        result_list = []

        # Everything before the first "Title: " is preamble, hence the [1:].
        for entry in input_string.split("Title: ")[1:]:
            if not entry.strip():
                # A dangling "Title: " with nothing after it carries no data.
                continue

            # Entries were produced by splitting on "Title: ", so the summary
            # simply runs to the end of the entry; partition() never raises
            # when the marker is absent.
            title, _, tldr = entry.partition("\nTldr: ")
            result_list.append({"title": title.strip(), "tldr": tldr.strip()})

        return result_list

    def summarize(self, target = "general"):
        """Summarize every staged category with the selected prompt template.

        Args:
            target: ``"general"`` or ``"scientific"``; selects the template.

        Returns:
            A list of ``{"category": name, "articles": [{"title", "tldr"}, ...]}``
            dicts, one per staged category (empty if nothing was staged).

        Raises:
            ValueError: if ``target`` is not one of the two supported values.
        """
        if target == "general":
            prompt_template = self.prompt_template_general
        elif target == "scientific":
            prompt_template = self.prompt_template_scientific
        else:
            raise ValueError("Error: TLDR.summarize() target must be 'general' or 'scientific'")

        prompt = PromptTemplate(template=prompt_template, input_variables=["field", "context"])
        # The chain's arguments do not vary per category, so build it once.
        chain = LLMChain(llm=self.llm, prompt=prompt)

        result = []
        for category in self.documents:
            category_name = category['category']
            joined_articles = "\n\n".join(category['articles'])
            output = chain.run(context=joined_articles, field=category_name)
            result.append({"category": category_name, "articles": self.parse_function(output)})

        return result




class PDFEmbeddings():
    """Index a directory of PDFs in a Chroma vector store and query it.

    The constructor loads the PDFs but does not embed them; call
    ``process_documents`` to add their chunks to the vector store.
    """

    def __init__(self, path: Optional[str] = None):
        """Set up splitter, embeddings, vector store, retriever and memory.

        Args:
            path: directory containing the PDFs. Defaults to the ``archive``
                folder under ``os.environ['CWD']``.
        """
        self.path = path or os.path.join(os.environ['CWD'], 'archive')
        self.text_splitter = CharacterTextSplitter(separator="\n", chunk_size=2000, chunk_overlap=200)
        self.embeddings = OpenAIEmbeddings(deployment=constants.AZURE_ENGINE_NAME_US, chunk_size=1,
                                           openai_api_key=constants.AZURE_OPENAI_KEY_US,
                                           openai_api_base=constants.AZURE_OPENAI_ENDPOINT_US,
                                           openai_api_version="2023-05-15",
                                           openai_api_type="azure",)
        self.vectorstore = Chroma(persist_directory=constants.persistent_dir, embedding_function=self.embeddings)
        # The keyword is "search_kwargs"; it was previously misspelled
        # ("search_kwags"), so the k=5 setting was never passed as search
        # arguments.
        self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        # ConversationalRetrievalChain requires a "chat_history" input, so the
        # memory must expose its buffer under that key for
        # conversational_search() to run.
        self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
        self.documents = self.load_documents()  # Load documents during initialization
        # self.process_documents()  # Process documents during initialization (?)

    def load_documents(self):
        """Load and return the documents found under ``self.path``."""
        return PyPDFDirectoryLoader(self.path).load()

    def process_documents(self):
        """Split the loaded documents into chunks and add them to the vector store."""
        chunks = self.text_splitter.split_documents(self.documents)
        if not chunks:
            # Nothing to embed (e.g. an empty directory); skip the store call.
            return
        self.vectorstore.add_documents(chunks)

    def semantic_search(self, num_queries):
        """Collect the distinct chunks matching the predefined queries.

        For every source among the loaded documents, runs the first
        ``num_queries`` values of ``constants.similarity_search_queries``
        against the vector store (2 hits per query, filtered to that source).

        Args:
            num_queries: how many of the predefined queries to use.

        Returns:
            A set with the ``str()`` form of every chunk returned.
        """
        document_sources = {doc.metadata['source'] for doc in self.documents}
        queries = list(constants.similarity_search_queries.values())[:num_queries]

        unique_chunks = set()
        for source in document_sources:
            for query in queries:
                results = self.vectorstore.similarity_search(query, k=2, filter={'source': source})
                # A set already ignores duplicates; no membership test needed.
                unique_chunks.update(str(chunk) for chunk in results)

        return unique_chunks

    def extract_queries_from_documents(self, num_similarity_search_queries= 3):
        """Ask the chat model for a comma-separated list derived from the documents.

        Args:
            num_similarity_search_queries: forwarded to ``semantic_search`` as
                the number of predefined queries to run.

        Returns:
            The model output parsed by ``CommaSeparatedListOutputParser``.
        """
        unique_chunks = self.semantic_search(num_similarity_search_queries)

        output_parser = CommaSeparatedListOutputParser()
        format_instructions = output_parser.get_format_instructions()
        prompt_template = ChatPromptTemplate.from_template(template=constants.pdf_template)

        chain = LLMChain(
            llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR),
            prompt=prompt_template,
        )
        # The chunks found by the semantic search are the prompt's context.
        output = chain.run(context=unique_chunks, format_instructions=format_instructions)

        return output_parser.parse(output)

    def search(self, query: str, chain_type: str = "stuff"):
        """Answer ``query`` with a RetrievalQA chain.

        Returns:
            The chain's full result; source documents are requested via
            ``return_source_documents=True``.
        """
        chain = RetrievalQA.from_chain_type(
            llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR, temperature=0),
            retriever=self.retriever, chain_type=chain_type, return_source_documents=True)
        return chain({"query": query})

    def conversational_search(self, query: str, chain_type: str = "stuff"):
        """Answer ``query`` using ``self.memory`` as conversation history.

        Returns:
            The ``'answer'`` entry of the chain result.
        """
        chain = ConversationalRetrievalChain.from_llm(
            llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR),
            retriever=self.retriever, memory=self.memory, chain_type=chain_type)
        result = chain({"question": query})
        return result['answer']

    def load_and_run_chain(self, query: str, chain_type: str = "stuff"):
        """Retrieve documents for ``query`` and run a QA chain over them.

        Returns:
            The output of the QA chain's ``run`` call.
        """
        chain = load_qa_chain(llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR), chain_type=chain_type)
        # The chain needs the retrieved documents themselves, not the
        # retriever object, so resolve the retriever first.
        relevant_documents = self.retriever.get_relevant_documents(query)
        return chain.run(input_documents=relevant_documents, question=query)

if __name__ == '__main__':

    # ---------------- Use case 1: PDF embeddings (disabled) ----------------
    # pdf_embed = PDFEmbeddings()
    # # Embedding is slow, so run process_documents() only once:
    # # pdf_embed.process_documents()
    # result = pdf_embed.extract_queries_from_documents(num_similarity_search_queries=5)
    # print("type of result: ", type(result))
    # for i in result:
    #     print(i)

    # ---------------- Use case 2: TLDR summaries ----------------
    # Summarize the bundled sample articles and print one title/TLDR pair
    # per article, grouped by category.
    summarizer = TLDR()
    summarizer.load_text(hardcoded_data.articles)

    # target accepts "general" or "scientific"
    for category_summary in summarizer.summarize(target="general"):
        print("\n\ncategory: ", category_summary['category'])
        for entry in category_summary['articles']:
            print("\n\nTitle: ", entry['title'])
            print("\n")
            print("Tldr: ", entry['tldr'])