Spaces:
Runtime error
Runtime error
| import glob | |
| import json | |
| import multiprocessing | |
| import os | |
| import re | |
| import shutil | |
| import sys | |
| import traceback | |
| from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor | |
| from functools import partial | |
| import torch | |
| from FlagEmbedding import BGEM3FlagModel | |
| from jinja2 import Template | |
| from tqdm import tqdm | |
# Placeholder credential for local runs. ``setdefault`` is used so that a real
# key already exported in the environment is NOT overwritten by the placeholder.
os.environ.setdefault('OPENAI_API_KEY', 'Your key here')
# Put the directory two levels above this file at the front of ``sys.path`` so
# that the ``src.*`` imports below resolve.
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.insert(0, root_dir)
| import src.llms as llms | |
| from src.induct import SlideInducter | |
| from src.model_utils import ( | |
| get_image_embedding, | |
| get_image_model, | |
| images_cosine_similarity, | |
| parse_pdf, | |
| prs_dedup, | |
| ) | |
| from src.multimodal import ImageLabler | |
| from src.presentation import Picture, Presentation, SlidePage | |
| from src.utils import Config, older_than, pexists, pjoin, ppt_to_images | |
# Matches Markdown image references of the form ``![alt](target)``; group 1
# captures the target inside the parentheses.
markdown_clean_pattern = re.compile(
    r"!\[.*?\]\((.*?)\)"
)
# Number of CUDA devices reported by torch; used below as the modulus when
# assigning a worker index to a device.
device_count = torch.cuda.device_count()
def rm_folder(folder: str):
    """Best-effort recursive removal of ``folder``.

    First tries ``shutil.rmtree``. If that fails, falls back to removing each
    child entry individually so that one failing entry does not prevent the
    others from being deleted. Filesystem errors are suppressed; a missing or
    non-directory path is simply ignored.

    Args:
        folder: Path of the directory to remove.
    """
    try:
        shutil.rmtree(folder)
        return
    except OSError:
        # Fall through to the per-entry fallback below. Only OSError is
        # caught so KeyboardInterrupt / SystemExit still propagate.
        pass

    try:
        entries = os.listdir(folder)
    except OSError:
        # The path does not exist or is not a directory: nothing to descend into.
        return

    for entry in entries:
        try:
            rm_folder(pjoin(folder, entry))
        except OSError:
            pass
def process_filetype(file_type: str, func: callable, thread_num: int, topic="*"):
    """Run ``func`` over every folder matching ``data/{topic}/{file_type}/*``.

    Each folder is handled on a thread pool of ``thread_num`` workers and
    ``func`` is invoked as ``func(folder, index)``, where ``index`` is the
    folder's position in the glob result. An exception raised by ``func`` is
    printed together with its traceback and does not stop the other folders.

    Args:
        file_type: Sub-directory name to scan (also used in progress/log text).
        func: Callable applied to each folder.
        thread_num: Number of worker threads.
        topic: Glob pattern for the topic directory level.
    """
    targets = glob.glob(f"data/{topic}/{file_type}/*")
    total = len(targets)
    bar = tqdm(total=total, desc=f"processing {file_type}")

    def _guarded(folder, *args, **kwargs):
        # Isolate failures per folder and always advance the progress bar.
        try:
            func(folder, *args, **kwargs)
        except Exception as err:
            print(f"process {file_type} folder {folder} failed: {err}")
            traceback.print_exc()
        finally:
            bar.update(1)

    with ThreadPoolExecutor(thread_num) as pool:
        # Drain the lazy map iterator so every task is actually waited on.
        for _ in pool.map(_guarded, targets, range(total)):
            pass

    bar.close()
def parse_pdfs(pdf_folders: list[str], idx: int):
    """Parse ``original.pdf`` in each folder, dropping folders with too little text.

    A folder is skipped when ``older_than`` rejects its ``original.pdf`` or
    when ``source.md`` already exists. Otherwise ``parse_pdf`` is called on
    the PDF, and the whole folder is removed if the returned text is shorter
    than 512 characters.

    Args:
        pdf_folders: Folders that each contain an ``original.pdf``.
        idx: Worker index; the model device is ``idx % device_count``.
    """
    # marker requires numpy==1.26.0, which conflicts with other packages;
    # it is imported here rather than at module level.
    from marker.models import create_model_dict

    # NOTE(review): the modulo raises ZeroDivisionError if device_count is 0.
    model = create_model_dict(device=idx % device_count, dtype=torch.float16)

    for pdf_folder in pdf_folders:
        pdf_path = pdf_folder + "/original.pdf"
        needs_parsing = older_than(pdf_path) and not pexists(
            pjoin(pdf_folder, "source.md")
        )
        if not needs_parsing:
            continue
        text_content = parse_pdf(pdf_path, pdf_folder, model)
        if len(text_content) < 512:
            rm_folder(pdf_folder)
def prepare_pdf_folder(pdf_folder: str, rank: int):
    """Generate ``image_caption.json`` and ``refined_doc.json`` for a parsed PDF folder.

    Nothing is done unless ``source.md`` exists in ``pdf_folder``. Each output
    file is only produced when it is missing:

    * ``image_caption.json``: images in the folder are embedded; an image
      whose similarity to any later image exceeds 0.85 is deleted; the
      remaining images are captioned with ``llms.vision_model``. If the
      folder yields no image embeddings, the whole folder is removed.
    * ``refined_doc.json``: ``source.md`` with Markdown image references
      stripped is rendered into the ``prompts/document_refine.txt`` template
      and sent to ``llms.language_model``.

    Args:
        pdf_folder: Folder containing ``source.md`` and the extracted images.
        rank: Worker index; the image model is placed on
            ``cuda:{rank % device_count}``.
    """
    if not pexists(pjoin(pdf_folder, "source.md")):
        return

    caption_path = pjoin(pdf_folder, "image_caption.json")
    if not pexists(caption_path):
        # Load the image model only when captions actually have to be built,
        # instead of on every call.
        image_model = get_image_model(f"cuda:{rank % device_count}")
        images_embeddings = get_image_embedding(pdf_folder, *image_model)
        if len(images_embeddings) == 0:
            rm_folder(pdf_folder)
            return
        images = [pjoin(pdf_folder, image) for image in images_embeddings]

        similarity_matrix = images_cosine_similarity(list(images_embeddings.values()))
        num_images = len(similarity_matrix)
        for i in range(num_images):
            # Of each near-duplicate group, the later image is the one kept.
            has_later_duplicate = any(
                similarity_matrix[i][j] > 0.85 for j in range(i + 1, num_images)
            )
            if has_later_duplicate and pexists(images[i]):
                os.remove(images[i])
        images = [image for image in images if pexists(image)]

        with open("prompts/caption.txt") as f:
            caption_prompt = f.read()
        image_stats = {}
        for image in images:
            image_stats[image] = llms.vision_model(caption_prompt, image)
            print(image_stats[image])
        with open(caption_path, mode="w") as f:
            json.dump(image_stats, f, indent=4, ensure_ascii=False)

    refined_path = pjoin(pdf_folder, "refined_doc.json")
    if not pexists(refined_path):
        with open(pjoin(pdf_folder, "source.md")) as f:
            text_content = markdown_clean_pattern.sub("", f.read())
        with open("prompts/document_refine.txt") as f:
            template = Template(f.read())
        doc_json = llms.language_model(
            template.render(markdown_document=text_content), return_json=True
        )
        # Context manager guarantees the JSON is flushed and the handle closed.
        with open(refined_path, "w") as f:
            json.dump(doc_json, f, indent=4, ensure_ascii=False)
def filter_slide(slide: SlidePage):
    """Tell whether ``slide`` should be filtered out.

    Returns ``True`` when any of these holds, otherwise ``None``:

    * the slide has more than 10 shapes;
    * fewer than 2 of its shapes are non-picture shapes;
    * it is not the slide with ``real_idx == 0`` and has more than 2 pictures.
    """
    picture_total = sum(1 for _ in slide.shape_filter(Picture))
    shape_total = len(slide.shapes)
    if (
        shape_total > 10
        or shape_total - picture_total < 2
        or (slide.real_idx != 0 and picture_total > 2)
    ):
        return True
    return None
def I_dont_want_to_filter_slide(slide: SlidePage):
    """No-op slide filter: report ``False`` for every slide so none is removed."""
    return False
def check_consistency(slides: list[SlidePage], ppt_folder: str, image_model):
    """Verify that each slide's image in ``source_slides`` matches ``original_slides``.

    For every slide, the embedding of ``original_slides/slide_{real_idx}.jpg``
    is compared with that of ``source_slides/slide_{slide_idx}.jpg``.

    Args:
        slides: Slides carrying ``real_idx`` and ``slide_idx`` attributes.
        ppt_folder: Folder holding the two image directories.
        image_model: Values unpacked as extra arguments to ``get_image_embedding``.

    Returns:
        ``True`` when every pair has cosine similarity of at least 0.9.

    Raises:
        ValueError: If any pair has cosine similarity below 0.9.
    """
    embeddings_before = get_image_embedding(
        pjoin(ppt_folder, "original_slides"), *image_model
    )
    embeddings_after = get_image_embedding(
        pjoin(ppt_folder, "source_slides"), *image_model
    )
    for page in slides:
        before = embeddings_before[f"slide_{page.real_idx:04d}.jpg"]
        after = embeddings_after[f"slide_{page.slide_idx:04d}.jpg"]
        similarity = torch.cosine_similarity(before, after, dim=-1)
        if similarity < 0.9:
            raise ValueError(f"slide {page.real_idx} in {ppt_folder} is inconsistent")
    return True
def prepare_ppt_folder(ppt_folder: str, text_model: BGEM3FlagModel, image_model):
    """Build ``source.pptx``, ``source_slides`` and ``template_images`` for one folder.

    Skips the folder when ``source.pptx`` already exists or when
    ``older_than`` rejects ``original.pptx``. Otherwise it:

    1. loads ``original.pptx`` and renders it to ``original_slides`` if that
       directory is missing;
    2. recreates ``source_slides`` as a copy of ``original_slides``;
    3. removes the images of slides dropped by ``prs_dedup``, by the slide
       filter, or recorded in ``presentation.error_history``;
    4. renumbers the remaining slides from 1 and renames their images to match;
    5. runs the consistency check, captions images, saves ``source.pptx`` and
       renders a layout-only copy into ``template_images``.

    Args:
        ppt_folder: Folder containing ``original.pptx``.
        text_model: Text embedding model passed to ``prs_dedup``.
        image_model: Image model passed to ``check_consistency``.
    """
    original_pptx = ppt_folder + "/original.pptx"
    if pexists(ppt_folder + "/source.pptx") or not older_than(original_pptx):
        return

    config = Config(rundir=ppt_folder, debug=False)
    presentation = Presentation.from_file(original_pptx, config=config)

    original_image_dir = pjoin(ppt_folder, "original_slides")
    if not os.path.exists(original_image_dir):
        ppt_to_images(presentation.source_file, original_image_dir)

    # Always start from a fresh copy of the original slide images.
    ppt_image_folder = pjoin(ppt_folder, "source_slides")
    shutil.rmtree(ppt_image_folder, ignore_errors=True)
    shutil.copytree(original_image_dir, ppt_image_folder)

    def slide_image(index):
        # Path of the rendered image for the slide numbered ``index``.
        return pjoin(ppt_image_folder, f"slide_{index:04d}.jpg")

    removed_slides = prs_dedup(presentation, text_model)
    # Snapshot the matches first: presentation.slides is mutated in the loop.
    filtered_out = [
        page for page in presentation.slides if I_dont_want_to_filter_slide(page)
    ]
    for page in filtered_out:
        removed_slides.append(page)
        presentation.slides.remove(page)

    for page in removed_slides:
        os.remove(slide_image(page.real_idx))
    for err_idx, _ in presentation.error_history:
        os.remove(slide_image(err_idx))

    # Remaining images must correspond one-to-one with remaining slides.
    assert len(presentation) == sum(
        1 for name in os.listdir(ppt_image_folder) if name.endswith(".jpg")
    )

    for position, page in enumerate(presentation.slides, 1):
        page.slide_idx = position
        os.rename(slide_image(page.real_idx), slide_image(page.slide_idx))

    check_consistency(presentation.slides, ppt_folder, image_model)
    ImageLabler(presentation, config).caption_images()
    presentation.save(pjoin(ppt_folder, "source.pptx"))

    # The layout-only deck is a temporary artifact used just for rendering.
    template_pptx = pjoin(ppt_folder, "template.pptx")
    presentation.save(template_pptx, layout_only=True)
    ppt_to_images(template_pptx, pjoin(ppt_folder, "template_images"))
    os.remove(template_pptx)
def prepare_induction(induct_id: int, wait: bool = False):
    """Run slide content induction over every folder matching ``data/*/pptx/*``.

    Args:
        induct_id: Index into the built-in list of
            ``(language model, vision model)`` pairs.
        wait: Forwarded to ``older_than`` when checking ``source.pptx``.
    """
    induct_llms = [
        (llms.qwen2_5, llms.qwen_vl),
        (llms.gpt4o, llms.gpt4o),
        (llms.qwen_vl, llms.qwen_vl),
    ]

    def do_induct(llm: list[llms.LLM], ppt_folder: str, rank: int):
        # Induce one folder with the given (language, vision) model pair.
        source_pptx = pjoin(ppt_folder, "source.pptx")
        if not older_than(source_pptx, wait=wait):
            return

        # Rebinds the module-level models on ``llms`` for the rest of the run.
        llms.language_model = llm[0]
        llms.vision_model = llm[1]

        config = Config(rundir=ppt_folder)
        image_model = get_image_model(f"cuda:{rank % device_count}")
        presentation = Presentation.from_file(source_pptx, config)
        ImageLabler(presentation, config).caption_images()
        SlideInducter(
            presentation,
            pjoin(ppt_folder, "source_slides"),
            pjoin(ppt_folder, "template_images"),
            config,
            image_model,
        ).content_induct()

    ppt_folders = sorted(glob.glob("data/*/pptx/*"))
    for folder in tqdm(ppt_folders, desc="prepare induction"):
        do_induct(induct_llms[induct_id], folder, 0)
if __name__ == "__main__":
    # Command-line entry point. Supported invocations:
    #   prepare_ppt                    - preprocess every data/*/pptx/* folder
    #   prepare_induction <induct_id>  - run induction with the chosen model pair
    #   parse_pdf <num_process>        - parse PDFs across worker processes
    #   prepare_pdf <thread_num>       - caption/refine parsed PDFs on threads
    usage = (
        "usage: prepare_ppt | prepare_induction <induct_id> | "
        "parse_pdf <num_process> | prepare_pdf <thread_num>"
    )
    # Fail with a readable message instead of an IndexError on missing arguments.
    if len(sys.argv) < 2:
        sys.exit(usage)
    command = sys.argv[1]
    needs_argument = ("prepare_induction", "parse_pdf", "prepare_pdf")
    if command in needs_argument and len(sys.argv) < 3:
        sys.exit(usage)

    if command == "prepare_ppt":
        text_model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True, device=0)
        image_model = get_image_model(0)
        for ppt_folder in tqdm(glob.glob("data/*/pptx/*"), desc="prepare ppt"):
            prepare_ppt_folder(ppt_folder, text_model, image_model)
    elif command == "prepare_induction":
        prepare_induction(int(sys.argv[2]))
    elif command == "parse_pdf":
        multiprocessing.set_start_method("spawn", force=True)
        num_process = int(sys.argv[2])
        folders = glob.glob("data/*/pdf/*")
        # Round-robin split: worker i gets folders i, i + n, i + 2n, ...
        subfolders = [folders[i::num_process] for i in range(num_process)]
        with ProcessPoolExecutor(max_workers=num_process) as executor:
            list(executor.map(parse_pdfs, subfolders, range(num_process)))
    elif command == "prepare_pdf":
        process_filetype("pdf", prepare_pdf_folder, int(sys.argv[2]))
    else:
        # An unknown command used to exit silently without doing anything.
        sys.exit(f"unknown command {command!r}\n{usage}")