Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import re | |
| import openai | |
| from langchain.prompts import PromptTemplate | |
| from config import TIMEOUT_STREAM, HISTORY_DIR | |
| from vector_db import upload_file | |
| from callback import StreamingGradioCallbackHandler | |
| from queue import SimpleQueue, Empty, Queue | |
| from threading import Thread | |
| from utils import add_source_numbers, add_details, web_citation, get_history_names | |
| from chains.custom_chain import CustomConversationalRetrievalChain | |
| from langchain.chains import LLMChain | |
| from chains.azure_openai import CustomAzureOpenAI | |
| from config import OPENAI_API_TYPE, OPENAI_API_VERSION, OPENAI_API_KEY, OPENAI_API_BASE, API_KEY, \ | |
| DEPLOYMENT_ID, MODEL_ID | |
| from cosmos_db import upsert_item, read_item, delete_items, query_items | |
class OpenAIModel:
    """Conversational QA front-end that streams answers into a Gradio-style chatbot list.

    ``inference`` routes each question to one of three back-ends:
    a retrieval chain over the vector store returned by ``upload_file``,
    a site search on FPTSoftware.com, or a decision model that chooses
    between a plain LLM chain and a Google search.

    Recent (question, answer) pairs are kept in ``self.history`` and passed
    to the chains as chat history. Conversations can be persisted through
    the ``cosmos_db`` helpers.
    """

    # Number of web search results used as context for a web-backed answer.
    _WEB_TOP_K = 4

    def __init__(
            self,
            llm_model_name,
            condense_model_name,
            prompt_template="",
            temperature=0.0,
            top_p=1.0,
            n_choices=1,
            stop=None,
            presence_penalty=0,
            frequency_penalty=0,
            user=None
    ):
        # NOTE(review): llm_model_name, prompt_template, n_choices, stop and the
        # two penalties are stored but not read by any method of this class;
        # inference() uses DEPLOYMENT_ID for the answering model.
        self.llm_model_name = llm_model_name
        # Azure deployment name used for the question-condensing model.
        self.condense_model_name = condense_model_name
        self.prompt_template = prompt_template
        self.temperature = temperature
        self.top_p = top_p
        self.n_choices = n_choices
        self.stop = stop
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        # List of (question, answer) tuples, trimmed by memory().
        self.history = []
        # Passed as the user argument to the cosmos_db history helpers.
        self.user_identifier = user

    def set_user_identifier(self, new_user_identifier):
        """Replace the user identifier used for history persistence."""
        self.user_identifier = new_user_identifier

    def format_prompt(self, qa_prompt_template, condense_prompt_template):
        """Build the LangChain prompts used by the retrieval chain.

        Returns ``(qa_prompt, condense_prompt)``. The QA template must use
        ``{question}``, ``{chat_history}`` and ``{context}``. The condense
        template must use ``{question}`` and ``{chat_history}``.
        """
        qa_prompt = PromptTemplate(template=qa_prompt_template,
                                   input_variables=["question", "chat_history", "context"])
        condense_prompt = PromptTemplate(template=condense_prompt_template,
                                         input_variables=["question", "chat_history"])
        return qa_prompt, condense_prompt

    def memory(self, inputs, outputs, last_k=3):
        """Record a (question, answer) turn, keeping only the ``last_k`` most recent.

        The list is trimmed in place, so it always ends up with at most
        ``last_k`` entries even if ``last_k`` shrinks between calls.
        ``last_k <= 0`` keeps no history.
        """
        self.history.append((inputs, outputs))
        if last_k > 0:
            del self.history[:-last_k]
        else:
            self.history.clear()

    def reset_conversation(self):
        """Clear the chat history and return an empty chatbot list for the UI."""
        self.history = []
        return []

    def delete_first_conversation(self):
        """Drop the oldest remembered turn, if any."""
        if self.history:
            self.history.pop(0)

    def delete_last_conversation(self):
        """Drop the most recent remembered turn, if any."""
        if self.history:
            self.history.pop()

    def save_history(self, chatbot, file_name):
        """Persist the history and chatbot via ``upsert_item``; returns its message."""
        return upsert_item(self.user_identifier, file_name, self.history, chatbot)

    def load_history(self, file_name):
        """Load a saved conversation; returns ``(item id, stored chatbot)``."""
        items = read_item(self.user_identifier, file_name)
        return items['id'], items['chatbot']

    def delete_history(self, file_name):
        """Delete a saved conversation.

        Returns ``(delete_items message, refreshed history names, empty chatbot list)``.
        """
        message = delete_items(self.user_identifier, file_name)
        return message, get_history_names(False, self.user_identifier), []

    def audio_response(self, audio):
        """Transcribe the audio file at path ``audio``.

        Returns ``(text, None)``. The file is closed even if the API call fails.
        """
        with open(audio, 'rb') as media_file:
            response = openai.Audio.transcribe(
                api_key=API_KEY,
                model=MODEL_ID,
                file=media_file
            )
        return response["text"], None

    def inference(self, inputs, chatbot, streaming=False, upload_files_btn=False, custom_websearch=False,
                  local_db=False,
                  **kwargs):
        """Answer ``inputs``, yielding ``(chatbot, status_text)`` UI updates.

        ``chatbot`` is mutated in place: an ``(inputs, answer)`` entry is
        appended and updated as the answer becomes available.

        Routing:
          * ``upload_files_btn`` or ``local_db`` truthy: retrieval QA over the
            vector store from ``upload_file``. ``kwargs`` are forwarded to
            :meth:`format_prompt` (``qa_prompt_template``,
            ``condense_prompt_template``). ``streaming`` streams tokens.
          * ``custom_websearch`` truthy: answer grounded on FPTSoftware.com search.
          * otherwise: a decision model picks a plain LLM answer or a
            Google-search-grounded answer.

        Raises:
            Whatever the retrieval chain raised on its worker thread.
        """
        if upload_files_btn or local_db:
            yield from self._inference_vectorstore(inputs, chatbot, streaming, upload_files_btn, kwargs)
        elif custom_websearch:
            yield from self._inference_custom_site(inputs, chatbot)
        else:
            yield from self._inference_routed(inputs, chatbot)

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _wrap_sources(items):
        """Join citation HTML snippets into the ``source-a`` container div."""
        return '<div class = "source-a">' + "\n".join(items) + '</div>'

    @staticmethod
    def _start_worker(work, token_queue, done_signal):
        """Run ``work()`` on a background thread; return ``(thread, result_queue)``.

        ``result_queue`` receives ``(True, value)`` or ``(False, exception)``.
        ``done_signal`` is always put on ``token_queue`` afterwards, even on
        failure, so a consumer blocked on ``token_queue`` cannot hang.
        """
        result_queue = SimpleQueue()

        def runner():
            try:
                result_queue.put((True, work()))
            except Exception as exc:  # forwarded and re-raised by the consumer
                result_queue.put((False, exc))
            finally:
                token_queue.put(done_signal)

        thread = Thread(target=runner)
        thread.start()
        return thread, result_queue

    @classmethod
    def _format_doc_sources(cls, relevant_docs, from_uploads):
        """Render citation HTML for retrieved documents ('' when there are none)."""
        if not relevant_docs:
            return ""
        if from_uploads:
            contents = [doc.page_content for doc in relevant_docs]
            sources = [doc.metadata["source"] for doc in relevant_docs]
            return cls._wrap_sources(add_details(contents, sources))
        anchors = []
        for number, doc in enumerate(relevant_docs, start=1):
            link = doc.metadata["source"]
            title = doc.page_content.split("\n")[0]
            # Strip non-word characters and surrounding blanks from the title.
            # NOTE(review): only the first 4 characters of the first line are
            # kept (title[:4]); confirm this truncation is intended.
            title = re.sub(r"[^\w\s]", "", title[:4]).strip()
            anchors.append(f'<a href="{link}" target="_blank">[{number}] {title}</a>')
        return cls._wrap_sources(anchors)

    def _inference_vectorstore(self, inputs, chatbot, streaming, upload_files_btn, prompt_kwargs):
        """Retrieval-QA route of :meth:`inference`; yields ``(chatbot, status_text)``."""
        yield chatbot, "Indexing files to vector database"
        vectorstore = upload_file(upload_files_btn)
        qa_prompt, condense_prompt = self.format_prompt(**prompt_kwargs)

        job_done = object()  # sentinel marking the end of the token stream
        token_queue = SimpleQueue()
        # Without streaming there is no callback and no timeout override.
        # Previously both names were left unbound, which raised UnboundLocalError.
        timeout = None
        callbacks = None
        if streaming:
            timeout = TIMEOUT_STREAM
            callbacks = [StreamingGradioCallbackHandler(token_queue)]

        llm = CustomAzureOpenAI(deployment_name=DEPLOYMENT_ID,
                                openai_api_type=OPENAI_API_TYPE,
                                openai_api_base=OPENAI_API_BASE,
                                openai_api_version=OPENAI_API_VERSION,
                                openai_api_key=OPENAI_API_KEY,
                                temperature=self.temperature,
                                model_kwargs={"top_p": self.top_p},
                                streaming=streaming,
                                callbacks=callbacks,
                                request_timeout=timeout)
        condense_llm = CustomAzureOpenAI(deployment_name=self.condense_model_name,
                                         openai_api_type=OPENAI_API_TYPE,
                                         openai_api_base=OPENAI_API_BASE,
                                         openai_api_version=OPENAI_API_VERSION,
                                         openai_api_key=OPENAI_API_KEY,
                                         temperature=self.temperature)
        status_text = "Request URL: " + OPENAI_API_BASE
        yield chatbot, status_text

        def run_chain():
            qa = CustomConversationalRetrievalChain.from_llm(
                llm,
                vectorstore.as_retriever(search_type="similarity_score_threshold",
                                         search_kwargs={"score_threshold": 0.75}),
                condense_question_llm=condense_llm, verbose=True,
                condense_question_prompt=condense_prompt,
                combine_docs_chain_kwargs={"prompt": qa_prompt},
                return_source_documents=True)
            return qa({"question": inputs, "chat_history": self.history})

        # Run the chain in a thread so streamed tokens can be yielded as they arrive.
        thread, result_queue = self._start_worker(run_chain, token_queue, job_done)
        chatbot.append((inputs, ""))
        content = ""
        for token in iter(token_queue.get, job_done):
            content += token
            chatbot[-1] = (chatbot[-1][0], content)
            yield chatbot, status_text

        succeeded, payload = result_queue.get()
        thread.join()
        if not succeeded:
            raise payload
        response = payload
        if not content:
            # No streamed tokens (streaming disabled): take the answer from the
            # chain output. NOTE(review): assumes the "answer" output key of
            # LangChain's ConversationalRetrievalChain; confirm for the custom chain.
            content = response.get("answer", "")

        display_append = self._format_doc_sources(response["source_documents"], upload_files_btn)
        chatbot[-1] = (chatbot[-1][0], content + display_append)
        self.memory(inputs, content)
        yield chatbot, status_text

    def _inference_custom_site(self, inputs, chatbot):
        """FPTSoftware.com site-search route of :meth:`inference`."""
        import requests

        yield chatbot, "Retrieving information from website FPTSoftware.com"
        params = {
            "q": inputs,
            # Raw strings keep the literal backslashes exactly as the original
            # non-raw literals did, without the invalid-escape warning.
            "v": r"\{539C9DC1-663A-418D-82A4-662D34EE34BC\}",
            "p": 10,
            "l": "en",
            "s": "{EACE8DB5-668F-4357-9782-405070D28D11}",
            "itemid": r"\{91F4101E-B1F3-4905-A832-96F703D3FBB1\}",
        }
        req = requests.get(
            "https://fptsoftware.com//sxa/search/results/?",
            params=params
        )
        req.raise_for_status()  # fail with a clear HTTP error rather than a JSON decode error
        res = json.loads(req.text)
        results = [{"link": "https://fptsoftware.com" + r["Url"]}
                   for r in res["Results"][:self._WEB_TOP_K]]
        yield from self._answer_with_web_results(inputs, chatbot, results, True)

    def _inference_routed(self, inputs, chatbot):
        """Decision-maker route of :meth:`inference`: plain LLM or Google search."""
        from chains.decision_maker import DecisionMaker
        from chains.simple_chain import SimpleChain

        decision = DecisionMaker().predict(question=inputs)
        if "LLM Model" in decision:
            status_text = "Request URL: " + OPENAI_API_BASE
            yield chatbot, status_text
            chatbot.append((inputs, ""))
            ai_response = SimpleChain().predict(question=inputs)
            chatbot[-1] = (chatbot[-1][0], ai_response)
            self.memory(inputs, ai_response)
            yield chatbot, status_text
            return

        from langchain.utilities.google_search import GoogleSearchAPIWrapper
        from config import GOOGLE_API_KEY, GOOGLE_CSE_ID

        yield chatbot, "Retrieving information from Google"
        search = GoogleSearchAPIWrapper(google_api_key=GOOGLE_API_KEY, google_cse_id=GOOGLE_CSE_ID)
        results = search.results(inputs, num_results=self._WEB_TOP_K)
        yield from self._answer_with_web_results(inputs, chatbot, results, False)

    def _answer_with_web_results(self, inputs, chatbot, results, from_custom_site):
        """Answer ``inputs`` using web ``results`` as context; shared by both web routes."""
        from chains.web_search import GoogleWebSearch

        reference_results, citations = web_citation(inputs, results, from_custom_site)
        reference_results = add_source_numbers(reference_results)
        display_append = self._wrap_sources(citations)
        status_text = "Request URL: " + OPENAI_API_BASE
        yield chatbot, status_text
        chatbot.append((inputs, ""))
        ai_response = GoogleWebSearch().predict(context="\n\n".join(reference_results), question=inputs,
                                                chat_history=self.history)
        chatbot[-1] = (chatbot[-1][0], ai_response + display_append)
        self.memory(inputs, ai_response)
        yield chatbot, status_text
if __name__ == '__main__':
    # Manual smoke test: send one prompt through the Azure deployment and print the reply.
    # CustomAzureOpenAI, LLMChain and the OPENAI_* settings come from the
    # module-level imports; only the chat prompt classes are imported here.
    from langchain.prompts.chat import (
        ChatPromptTemplate,
        HumanMessagePromptTemplate,
        SystemMessagePromptTemplate,
    )

    SYSTEM_PROMPT_TEMPLATE = "You're a helpful assistant."
    HUMAN_PROMPT_TEMPLATE = "Human: {question}\n AI answer:"
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
            HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE),
        ]
    )
    # NOTE(review): hard-coded deployment name; OpenAIModel uses DEPLOYMENT_ID.
    # openai_api_type is passed for consistency with every LLM built in OpenAIModel.
    llm = CustomAzureOpenAI(deployment_name="binh-gpt",
                            openai_api_type=OPENAI_API_TYPE,
                            openai_api_key=OPENAI_API_KEY,
                            openai_api_base=OPENAI_API_BASE,
                            openai_api_version=OPENAI_API_VERSION,
                            temperature=0,
                            model_kwargs={"top_p": 1.0})
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt
    )
    results = llm_chain.predict(question="Hello")
    print(results)