Spaces:
Runtime error
Runtime error
| import re | |
| import os | |
| import logging | |
| from typing import List | |
| from opencc import OpenCC | |
| import openai | |
| import tiktoken | |
# OpenAI API key taken from the process environment at import time
# (None when the variable is not set).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
class GPTAgent:
    """Base class for the OpenAI chat helpers in this module.

    Stores the model name and the sampling parameters that subclasses pass to
    ``openai.ChatCompletion.create``, and provides text utilities: token-budget
    chunking (``split_into_many``), newline flattening (``preprocess``) and
    parsing of comma-separated model replies (``parse_result``).
    """

    # Separators a reply may use between keywords/topics: ASCII comma,
    # full-width comma and the ideographic enumeration comma.
    _KEYWORD_SEPARATORS = re.compile(r"[,,、]")

    def __init__(self, model):
        # Sets the API key on the ``openai`` module itself (process-wide state).
        openai.api_key = OPENAI_API_KEY
        self.model = model
        self.temperature = 0.8
        self.frequency_penalty = 0
        self.presence_penalty = 0.6
        self.max_tokens = 2048
        # NOTE(review): not read by any method in this file; the chunk budget
        # is the ``max_tokens`` argument of ``split_into_many``.
        self.split_max_tokens = 13000

    def request(self, messages):
        """Send chat ``messages`` to the model and return the reply text.

        If an ``agent`` object exposing ``complete(messages=...)`` has been
        attached to the instance, it is used. Otherwise the OpenAI
        ChatCompletion endpoint is called with this instance's model and
        sampling parameters.
        """
        # Nothing in this module assigns ``self.agent``; without this fallback
        # the method could only raise AttributeError.
        agent = getattr(self, "agent", None)
        if agent is not None:
            response = agent.complete(messages=messages)
            return response.choices[0].message["content"]
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        return response["choices"][0]["message"]["content"]

    def split_into_many(self, text, max_tokens=3000, tokenizer=None):
        """Split ``text`` into chunks of roughly at most ``max_tokens`` tokens.

        The text is cut on the Chinese full stop "。", consecutive sentences
        are packed into chunks, and each chunk is re-joined with "。" (so every
        chunk ends with "。"). A single sentence that alone exceeds the budget
        (e.g. English text containing no "。") is cut into smaller pieces
        rather than discarded.

        Args:
            text: Input text.
            max_tokens: Per-chunk token budget.
            tokenizer: Object with an ``encode(str)`` method returning a
                sequence of tokens; defaults to tiktoken's ``cl100k_base``.

        Returns:
            A non-empty list of chunk strings; ``[text]`` when ``text`` has no
            non-blank sentence.
        """
        if tokenizer is None:
            tokenizer = tiktoken.get_encoding("cl100k_base")

        def count_tokens(piece):
            # Token count of a piece with a leading space prepended.
            return len(tokenizer.encode(" " + piece))

        chunks = []
        chunk = []
        tokens_so_far = 0
        for sentence in text.split("。"):
            # Blank pieces (e.g. after a trailing "。") would only add a stray "。".
            if not sentence.strip():
                continue
            sentence_tokens = count_tokens(sentence)
            if sentence_tokens <= max_tokens:
                pieces = [(sentence, sentence_tokens)]
            else:
                # Cut instead of dropping, so long "。"-free text is not lost.
                pieces = [
                    (piece, count_tokens(piece))
                    for piece in self._split_oversized(sentence, max_tokens, count_tokens)
                ]
            for piece, piece_tokens in pieces:
                if chunk and tokens_so_far + piece_tokens > max_tokens:
                    chunks.append("。".join(chunk) + "。")
                    chunk = []
                    tokens_so_far = 0
                chunk.append(piece)
                # +1 approximates the "。" separator added when joining.
                tokens_so_far += piece_tokens + 1
        if chunk:
            chunks.append("。".join(chunk) + "。")
        return chunks or [text]

    @staticmethod
    def _split_oversized(sentence, max_tokens, count_tokens):
        """Cut one over-budget sentence into consecutive pieces that fit ``max_tokens``.

        Binary-searches the longest fitting prefix each time; a piece always
        has at least one character, so the loop always makes progress.
        """
        pieces = []
        rest = sentence
        while rest:
            # Bounding the search window keeps each step cheap on huge inputs;
            # a shorter-than-optimal piece is still correct.
            lo, hi = 1, min(len(rest), max_tokens * 32)
            while lo < hi:
                mid = (lo + hi + 1) // 2
                if count_tokens(rest[:mid]) <= max_tokens:
                    lo = mid
                else:
                    hi = mid - 1
            pieces.append(rest[:lo])
            rest = rest[lo:]
        return pieces

    def preprocess(self, text):
        """Flatten ``text`` to one line: "\\n" becomes a space, "\\r" is removed."""
        return text.replace("\n", " ").replace("\r", "")

    def parse_result(self, result):
        """Flatten model replies into a list of individual keywords/topics.

        Each reply in ``result`` is split on any mix of ASCII commas, full-width
        commas and "、". Every piece is converted to Traditional Chinese
        (Taiwan) with OpenCC, stripped, and has "。" removed; empty pieces are
        dropped. ``result`` itself is not modified.

        Args:
            result: Iterable of reply strings.

        Returns:
            List of cleaned, non-empty words.
        """
        chinese_converter = OpenCC("s2tw")
        parsed_result = []
        for reply in result:
            for word in self._KEYWORD_SEPARATORS.split(reply):
                try:
                    cleaned = chinese_converter.convert(word).strip().replace("。", "")
                except Exception as e:
                    logging.error(e)
                    logging.error("Failed to parse result")
                    continue
                if cleaned:
                    parsed_result.append(cleaned)
        return parsed_result
class Translator(GPTAgent):
    """Translates English text into literary Traditional Chinese with gpt-3.5-turbo."""

    # System prompt (Traditional Chinese). In English: act as a Chinese
    # translator, spelling corrector and improver; translate the user's English
    # into Traditional Chinese, keep the meaning but make it more literary,
    # reply only with the corrected text, no explanations, no Simplified Chinese.
    _SYSTEM_PROMPT_ZH_TW = (
        "\n"
        "我希望你擔任中文翻譯、拼寫糾正及改進的角色。\n"
        "我將用英文與你交流,請將其翻譯並用繁體中文回答,同時對我的文本進行糾正和改進。\n"
        "保持原意不變,但使其更具文學性。我希望你僅回覆更正、改進的部分,不要寫解釋,也不要使用任何简体中文。\n"
    )

    def __init__(self):
        super().__init__("gpt-3.5-turbo")

    def translate_to_chinese(self, text):
        """Translate ``text`` into Traditional Chinese.

        Args:
            text: Source text (expected to be English, per the prompt).

        Returns:
            The model reply, stripped and converted to Traditional Chinese
            (Taiwan) with OpenCC.

        Raises:
            Whatever the OpenAI client raises; the error is logged first.
        """
        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT_ZH_TW},
            {"role": "user", "content": text},
        ]
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
            )
        except Exception as e:
            logging.error(e)
            logging.error("Failed to translate to Chinese")
            # No response to read; surface the real cause rather than an
            # UnboundLocalError on ``response`` below.
            raise
        # The model may still emit Simplified characters; normalize them.
        chinese_converter = OpenCC("s2tw")
        return chinese_converter.convert(
            response["choices"][0]["message"]["content"].strip()
        )
class EmbeddingGenerator(GPTAgent):
    """Wraps the OpenAI Embedding endpoint.

    The chat model name passed to the base class is not used by
    ``get_embedding``, which always requests ``text-embedding-ada-002``.
    """

    def __init__(self):
        super().__init__("text-davinci-002")

    def get_embedding(self, text):
        """Return the embedding vector of ``text`` from ``text-embedding-ada-002``."""
        result = openai.Embedding.create(input=text, engine="text-embedding-ada-002")
        first_item = result["data"][0]
        return first_item["embedding"]
class KeywordsGenerator(GPTAgent):
    """Extracts search keywords from text with gpt-3.5-turbo."""

    # Prompt (Traditional Chinese): "Extract 5 keywords from the following
    # content for searching this article, separated by 「,」".
    _SYSTEM_PROMPT = "\n請你為以下內容抓出 5 個關鍵字用以搜尋這篇文章,並用「,」來分隔\n"

    # Completion budget for the keyword reply. The API rejects requests whose
    # prompt plus ``max_tokens`` exceeds the model context; with ~3000-token
    # input chunks, reserving the inherited 2048 tokens can overflow a
    # 4k-context gpt-3.5-turbo (NOTE(review): confirm the deployed model's
    # context size). A short keyword list needs far fewer tokens.
    keywords_max_tokens = 256

    def __init__(self):
        super().__init__("gpt-3.5-turbo")

    def extract_keywords(self, text):
        """Return keywords for ``text``, one request per chunk.

        The text is chunked with ``split_into_many``. Each chunk is flattened
        with ``preprocess`` and sent with the keyword prompt. Failed requests
        are logged and skipped. The replies are parsed with ``parse_result``.

        Returns:
            Flat list of keyword strings from all successful chunks.
        """
        keywords = []
        for chunk in self.split_into_many(text):
            messages = [
                {"role": "system", "content": self._SYSTEM_PROMPT},
                {"role": "user", "content": self.preprocess(chunk)},
            ]
            try:
                response = openai.ChatCompletion.create(
                    model=self.model,
                    messages=messages,
                    temperature=0,
                    max_tokens=min(self.max_tokens, self.keywords_max_tokens),
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty,
                )
                keywords.append(response["choices"][0]["message"]["content"].strip())
            except Exception as e:
                # Best effort: one failed chunk should not discard the others.
                logging.error(e)
                logging.error("Failed to extract keywords")
        return self.parse_result(keywords)
class TopicsGenerator(GPTAgent):
    """Assigns high-level topic labels to text with gpt-3.5-turbo."""

    # Prompt (Traditional Chinese): "Give 3 highly abstract topic categories
    # for the following content, separated by 「,」".
    _SYSTEM_PROMPT = "\n請你為以下內容給予 3 個高度抽象的主題分類這篇文章,並用「,」來分隔\n"

    # Completion budget for the topic reply. The API rejects requests whose
    # prompt plus ``max_tokens`` exceeds the model context; with ~3000-token
    # input chunks, reserving the inherited 2048 tokens can overflow a
    # 4k-context gpt-3.5-turbo (NOTE(review): confirm the deployed model's
    # context size). Three short labels need far fewer tokens.
    topics_max_tokens = 256

    def __init__(self):
        super().__init__("gpt-3.5-turbo")

    def extract_topics(self, text):
        """Return topic labels for ``text``, one request per chunk.

        The text is chunked with ``split_into_many``. Each chunk is flattened
        with ``preprocess`` and sent with the topic prompt. Failed requests are
        logged and skipped. The replies are parsed with ``parse_result``.

        Returns:
            Flat list of topic strings from all successful chunks.
        """
        topics = []
        for chunk in self.split_into_many(text):
            messages = [
                {"role": "system", "content": self._SYSTEM_PROMPT},
                {"role": "user", "content": self.preprocess(chunk)},
            ]
            try:
                response = openai.ChatCompletion.create(
                    model=self.model,
                    messages=messages,
                    temperature=0,
                    max_tokens=min(self.max_tokens, self.topics_max_tokens),
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty,
                )
                topics.append(response["choices"][0]["message"]["content"].strip())
            except Exception as e:
                # Best effort: one failed chunk should not discard the others.
                logging.error(e)
                logging.error("Failed to extract topics")
        return self.parse_result(topics)
class Summarizer(GPTAgent):
    """Summarizes text with gpt-3.5-turbo-16k, map-reduce style for long input."""

    # Prompt (Traditional Chinese): "Please summarize the following article for me."
    _SYSTEM_PROMPT = "\n請幫我總結以下的文章。\n"

    def __init__(self):
        super().__init__("gpt-3.5-turbo-16k")

    def _summarize_once(self, content):
        """Run one summarization request and return the stripped reply in Traditional Chinese.

        Raises whatever the OpenAI client raises.
        """
        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": content},
        ]
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        reply = response["choices"][0]["message"]["content"].strip()
        return OpenCC("s2tw").convert(reply)

    def summarize(self, text):
        """Summarize ``text``.

        If ``split_into_many`` yields a single chunk, ``text`` is summarized in
        one request. Otherwise the flow is map-reduce:

        1. Each chunk is flattened with ``preprocess`` and summarized.
        2. The partial summaries are joined with newlines.
        3. The joined text is summarized again.

        Chunk requests that fail are logged and skipped.

        Returns:
            The summary, converted to Traditional Chinese (Taiwan).

        Raises:
            The OpenAI client's exception if the single or final request
            fails (logged first); RuntimeError if every chunk request fails.
        """
        text_chunks = self.split_into_many(text)

        if len(text_chunks) <= 1:
            try:
                summary = self._summarize_once(text)
            except Exception as e:
                logging.error(e)
                logging.error("Failed to summarize")
                raise
            logging.debug("the summary is %s", summary)
            return summary

        partial_summaries = []
        last_error = None
        for chunk in text_chunks:
            try:
                # Summarize the chunk itself (not a single character of ``text``).
                partial_summaries.append(self._summarize_once(self.preprocess(chunk)))
            except Exception as e:
                logging.error(e)
                logging.error("Failed to summarize text_chunk")
                last_error = e
        if not partial_summaries:
            raise RuntimeError("Failed to summarize any text chunk") from last_error

        concated_summary = "\n".join(partial_summaries)
        try:
            return self._summarize_once(concated_summary)
        except Exception as e:
            logging.error(e)
            logging.error("Failed to summarize concated_summary")
            raise
class QuestionAnswerer(GPTAgent):
    """Answers questions about document text with gpt-3.5-turbo-16k."""

    # Prompt (Traditional Chinese): "You are a knowledge retrieval system. I will
    # give you a document; answer questions based on its content, in Traditional
    # Chinese. The document content follows."
    _SYSTEM_PROMPT = (
        "\n你是一個知識檢索系統,我會給你一份文件,請幫我依照文件內容回答問題,"
        "並用繁體中文回答。以下是文件內容\n"
    )

    def __init__(self):
        super().__init__("gpt-3.5-turbo-16k")

    def _system_message(self, document):
        """Build the system message: instruction prompt, newline, then the document in ''' quotes."""
        return {"role": "system", "content": f"{self._SYSTEM_PROMPT}\n'''{document}'''"}

    def answer_chunk_question(self, text, question):
        """Answer ``question`` against each chunk of ``text`` and join the answers.

        Each chunk from ``split_into_many`` is placed in the system message and
        ``question`` is sent as the user message. Failed requests are logged
        and skipped.

        Returns:
            The per-chunk answers (stripped, converted to Traditional Chinese)
            joined with "。"; an empty string if every request failed.
        """
        chinese_converter = OpenCC("s2tw")
        answer_chunks = []
        for chunk in self.split_into_many(text):
            messages = [
                self._system_message(chunk),
                {"role": "user", "content": f"{question}"},
            ]
            try:
                response = openai.ChatCompletion.create(
                    model=self.model,
                    messages=messages,
                    temperature=self.temperature,
                    max_tokens=1024,
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty,
                )
            except Exception as e:
                # Skip this chunk rather than reuse a stale/unbound ``response``.
                logging.error(e)
                logging.error("Failed to answer question")
                continue
            answer_chunks.append(
                chinese_converter.convert(
                    response["choices"][0]["message"]["content"].strip()
                )
            )
        return "。".join(answer_chunks)

    def answer_question(self, context, context_page_num, context_file_name, history):
        """Answer the conversation in ``history`` using ``context`` as the document.

        Args:
            context: Document text placed in the system message.
            context_page_num: Page number shown in the answer header.
            context_file_name: File name shown in the answer header.
            history: Sequence of (user_message, assistant_message) pairs; the
                assistant entry may be None. Only the last 10 pairs are sent.

        Returns:
            A "source file / page" header followed by the model answer,
            converted to Traditional Chinese (Taiwan); None if the request
            fails (the error is logged).
        """
        messages = [self._system_message(context)] + self.__construct_message_history(history)
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=2048,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
            )
            chinese_converter = OpenCC("s2tw")
            page_num_message = f"以下內容來自 {context_file_name},第 {context_page_num} 頁\n\n"
            bot_answer = response["choices"][0]["message"]["content"]
            return chinese_converter.convert(page_num_message + bot_answer)
        except Exception as e:
            logging.error(e)
            logging.error("Failed to answer question")
            return None

    def __construct_message_history(self, history):
        """Convert the last 10 chat pairs into OpenAI chat messages.

        Each pair adds a "user" message from element 0 and, when element 1 is
        not None, an "assistant" message from element 1.
        """
        # DEBUG log instead of print: the history contains user content.
        logging.debug("history is %s", history)
        max_history_length = 10
        messages = []
        for turn in history[-max_history_length:]:
            messages.append({"role": "user", "content": turn[0]})
            if turn[1] is not None:
                messages.append({"role": "assistant", "content": turn[1]})
        return messages