Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import logging | |
| import hashlib | |
| import pandas as pd | |
| from .gpt_processor import ( | |
| EmbeddingGenerator, | |
| KeywordsGenerator, | |
| Summarizer, | |
| TopicsGenerator, | |
| Translator, | |
| ) | |
| from .pdf_processor import PDFProcessor | |
# Registry mapping a file extension to the processor class that knows how to
# extract ``file_info`` for that format. Extend this dict to support more types.
processors = dict(
    pdf=PDFProcessor,
)
class WorkFlowController:
    """Drive the document-ingestion pipeline for a batch of uploaded files.

    For every input file this controller:
      1. picks a format-specific processor (see module-level ``processors``),
      2. translates the content to Chinese when needed,
      3. attaches per-page embeddings and a whole-document summary,
      4. persists the accumulated knowledge base as JSON and CSV in the
         current working directory, named ``{uid}_knowledge_base.{json,csv}``.
    """

    def __init__(self, file_src, uid) -> None:
        """Process every file in *file_src* and dump the results to disk.

        Args:
            file_src: iterable of objects exposing a ``.name`` attribute with
                the file path (e.g. gradio upload objects) -- TODO confirm.
            uid: identifier used to namespace the output files.
        """
        self.file_paths = [x.name for x in file_src]
        self.uid = uid
        print(self.file_paths)
        self.files_info = {}
        for file_path in self.file_paths:
            file_name = os.path.basename(file_path)
            # Normalize the extension so "report.PDF" selects the same
            # processor as "report.pdf". An unsupported format still raises
            # KeyError from the registry lookup, as before.
            file_format = os.path.splitext(file_path)[1].lstrip(".").lower()
            self.file_processor = processors[file_format]
            file = self.file_processor(file_path).file_info
            file = self.__process_file(file)
            self.files_info[file_name] = file
        self.__dump_to_json()
        self.__dump_to_csv()

    def __get_summary(self, file: dict) -> dict:
        """Attach a summary of the whole document under ``summarized_content``."""
        summarizer = Summarizer()
        file["summarized_content"] = summarizer.summarize(file["file_full_content"])
        return file

    def __get_keywords(self, file: dict) -> dict:
        """Attach extracted keywords under ``keywords`` (currently unused)."""
        keywords_generator = KeywordsGenerator()
        file["keywords"] = keywords_generator.extract_keywords(
            file["file_full_content"]
        )
        return file

    def __get_topics(self, file: dict) -> dict:
        """Attach extracted topics under ``topics`` (currently unused)."""
        topics_generator = TopicsGenerator()
        file["topics"] = topics_generator.extract_topics(file["file_full_content"])
        return file

    def __get_embedding(self, file: dict) -> dict:
        """Attach a ``page_embedding`` vector to every page of the file.

        ``file["file_content"]`` is keyed by 1-based page number, hence the
        range starting at 1.
        """
        embedding_generator = EmbeddingGenerator()
        pages = file["file_content"]
        for page_num in range(1, len(pages) + 1):
            pages[page_num]["page_embedding"] = embedding_generator.get_embedding(
                pages[page_num]["page_content"]
            )
        return file

    def __translate_to_chinese(self, file: dict) -> dict:
        """Translate every page to Chinese in place and rebuild the full text.

        ``file_full_content`` is reset and re-concatenated from the translated
        pages so later steps (summary, keywords) operate on the translation.
        """
        translator = Translator()
        file["file_full_content"] = ""
        pages = file["file_content"]
        # Pages are keyed 1..N, hence the range starting at 1.
        for page_num in range(1, len(pages) + 1):
            print("Translating page: " + str(page_num))
            pages[page_num]["page_content"] = translator.translate_to_chinese(
                pages[page_num]["page_content"]
            )
            file["file_full_content"] = (
                file["file_full_content"] + pages[page_num]["page_content"]
            )
        return file

    def __process_file(self, file: dict) -> dict:
        """Run the enrichment pipeline (translate -> embed -> summarize)."""
        if not file["is_chinese"]:
            print("Translating to chinese...")
            file = self.__translate_to_chinese(file)
        print("Getting embedding...")
        file = self.__get_embedding(file)
        print("Getting summary...")
        file = self.__get_summary(file)
        return file

    def __dump_to_json(self):
        """Write ``files_info`` to ``{uid}_knowledge_base.json`` in the CWD.

        Side effect: records the output path in ``self.json_result_path``.
        """
        # Compute the path once instead of rebuilding it three times.
        self.json_result_path = os.path.join(
            os.getcwd(), f"{self.uid}_knowledge_base.json"
        )
        print("Dumping to json, the path is: " + self.json_result_path)
        with open(self.json_result_path, "w", encoding="utf-8") as f:
            json.dump(self.files_info, f, indent=4, ensure_ascii=False)

    def __construct_knowledge_base_dataframe(self) -> pd.DataFrame:
        """Flatten ``files_info`` into one row per page for CSV export."""
        columns = [
            "file_name",
            "page_num",
            "page_content",
            "page_embedding",
        ]
        rows = [
            {
                "file_name": content["file_name"],
                "page_num": page_details["page_num"],
                "page_content": page_details["page_content"],
                "page_embedding": page_details["page_embedding"],
            }
            for content in self.files_info.values()
            for page_details in content["file_content"].values()
        ]
        return pd.DataFrame(rows, columns=columns)

    def __dump_to_csv(self):
        """Write the flattened knowledge base to ``{uid}_knowledge_base.csv``.

        Side effect: records the output path in ``self.csv_result_path``.
        """
        self.csv_result_path = os.path.join(
            os.getcwd(), f"{self.uid}_knowledge_base.csv"
        )
        df = self.__construct_knowledge_base_dataframe()
        df.to_csv(self.csv_result_path, index=False)
        print("Dumping to csv, the path is: " + self.csv_result_path)

    def __get_file_name(self, file_src):
        """Return an MD5 digest of the combined file contents.

        NOTE(review): despite the name, this returns a content hash, not a
        file name; it is currently unused (see the commented-out call in
        ``__init__``) but kept for a content-addressed cache key.
        """
        file_paths = [x.name for x in file_src]
        file_paths.sort(key=lambda x: os.path.basename(x))
        md5_hash = hashlib.md5()
        for file_path in file_paths:
            with open(file_path, "rb") as f:
                # Hash in 8 KiB chunks to keep memory flat on large files.
                while chunk := f.read(8192):
                    md5_hash.update(chunk)
        return md5_hash.hexdigest()