Spaces:
Build error
Build error
| from dotenv import load_dotenv | |
| from utils.src.utils import ppt_to_images, get_json_from_response | |
| import json | |
| from camel.models import ModelFactory | |
| from camel.agents import ChatAgent | |
| from utils.wei_utils import * | |
| from camel.messages import BaseMessage | |
| from PIL import Image | |
| import pickle as pkl | |
| from utils.pptx_utils import * | |
| from utils.critic_utils import * | |
| import yaml | |
| import argparse | |
| import shutil | |
| from jinja2 import Environment, StrictUndefined | |
| from concurrent.futures import ThreadPoolExecutor | |
| import copy | |
| load_dotenv() | |
| MAX_ATTEMPTS = 5 | |
def process_leaf_section(
    leaf_section,
    section_name,
    outline,
    content,
    style_logs,
    critic_logs,
    actor_logs,
    img_logs,
    slide_width,
    slide_height,
    name_to_hierarchy,
    critic_template,
    actor_template,
    critic_agent,
    actor_agent,
    neg_img,
    pos_img,
    MAX_ATTEMPTS,
    documentation,
    total_input_token,
    total_output_token,
):
    """
    Run the critic/actor refinement loop for one leaf section.

    Starting from the latest code stored in ``style_logs[section_name]``, each
    round takes a snapshot of the leaf, shows it to the critic together with
    the two example images, and stops if the critic answers with a bare "NO".
    Otherwise the critic's response is parsed into feedback and handed to the
    actor, whose resulting code becomes the input of the next round.

    Args:
        leaf_section: Key of the leaf; looked up in ``outline``/``content``
            directly, or under ``[section_name]['subsections']`` if absent.
        section_name: Key of the enclosing section in ``style_logs``.
        outline: Outline dict providing the leaf's ``'name'``.
        content: Content dict providing the leaf's content for the prompts.
        style_logs: Per-section log lists; only the last entry's ``'code'``
            of ``section_name`` is read here.
        critic_logs, actor_logs, img_logs: Dicts keyed by leaf section; this
            function appends to the list for ``leaf_section`` (in place).
        slide_width, slide_height: Size of the blank poster created per round.
        name_to_hierarchy: Passed through to ``get_snapshot_from_section``.
        critic_template, actor_template: Templates exposing ``render(**kw)``.
        critic_agent, actor_agent: Agents used for critique and code editing.
        neg_img, pos_img: Example images sent to the critic before the snapshot.
        MAX_ATTEMPTS: Maximum number of critic rounds.
        documentation: Rendered into the actor prompt as ``function_docs``.
        total_input_token, total_output_token: Running token counters; the
            tokens used here are added to them.

    Returns:
        dict with keys ``section_code``, ``log``, ``img_logs``,
        ``critic_logs``, ``actor_logs``, ``total_input_token`` and
        ``total_output_token``.

    Raises:
        RuntimeError: If the last entry of the actor's edit log has an error.
    """
    section_code = style_logs[section_name][-1]['code']  # current code for this section
    log = []

    if leaf_section in outline:
        leaf_name = outline[leaf_section]['name']
    else:
        leaf_name = outline[section_name]['subsections'][leaf_section]['name']

    # Responses treated as "nothing to fix". The critic's reply is compared
    # after trimming whitespace and upper-casing, so "NO\n" or "no" also stop.
    no_verdicts = ('NO', 'NO.', '"NO"', "'NO'")
    empty_poster_path = f'tmp/empty_poster_{section_name}_{leaf_section}.pptx'

    for round_idx in range(MAX_ATTEMPTS):
        print(f"Section: {section_name}, Leaf Section: {leaf_section}, Round: {round_idx}")

        # Fresh blank poster for this round's snapshot.
        poster = create_poster(slide_width, slide_height)
        add_blank_slide(poster)
        save_presentation(poster, file_name=empty_poster_path)

        # Only the image is needed; the location and image path are unused.
        _, zoomed_in_img, _ = get_snapshot_from_section(
            leaf_section,
            section_name,
            name_to_hierarchy,
            leaf_name,
            section_code,
            empty_poster_path
        )
        img_logs.setdefault(leaf_section, []).append(zoomed_in_img)

        # Content for this leaf, shared by the critic and actor prompts.
        if leaf_section in content:
            leaf_content = content[leaf_section]
        else:
            leaf_content = content[section_name]['subsections'][leaf_section]

        critic_prompt = critic_template.render(
            content_json=leaf_content,
            existing_code=section_code,
        )
        critic_msg = BaseMessage.make_user_message(
            role_name="User",
            content=critic_prompt,
            image_list=[neg_img, pos_img, zoomed_in_img],
        )

        # The critic is reset so each round is judged without earlier history.
        critic_agent.reset()
        response = critic_agent.step(critic_msg)
        resp = response.msgs[0].content

        # Track tokens
        input_token, output_token = account_token(response)
        total_input_token += input_token
        total_output_token += output_token

        critic_logs.setdefault(leaf_section, []).append(response)

        # Stop condition
        if isinstance(resp, str) and resp.strip().upper() in no_verdicts:
            break

        feedback = get_json_from_response(resp)
        print(feedback)

        actor_prompt = actor_template.render(
            content_json=leaf_content,
            function_docs=documentation,
            existing_code=section_code,
            suggestion_json=feedback,
        )
        # NOTE(review): the literal 3 is presumably edit_code's retry budget - confirm.
        leaf_log = edit_code(actor_agent, actor_prompt, 3, existing_code='')
        last_entry = leaf_log[-1]
        if last_entry['error'] is not None:
            raise RuntimeError(
                f"Actor failed on section '{section_name}', "
                f"leaf '{leaf_section}': {last_entry['error']}"
            )

        # Track tokens
        total_input_token += last_entry['cumulative_tokens'][0]
        total_output_token += last_entry['cumulative_tokens'][1]

        # The edited code is what the next round renders and critiques.
        section_code = last_entry['code']

        actor_logs.setdefault(leaf_section, []).append(leaf_log)
        log.extend(leaf_log)

    return {
        "section_code": section_code,
        "log": log,
        "img_logs": img_logs,
        "critic_logs": critic_logs,
        "actor_logs": actor_logs,
        "total_input_token": total_input_token,
        "total_output_token": total_output_token,
    }
def process_section(
    section_name,
    content,
    outline,
    sections,
    style_logs,
    critic_logs,
    actor_logs,
    img_logs,
    slide_width,
    slide_height,
    name_to_hierarchy,
    critic_template,
    actor_template,
    critic_agent,
    actor_agent,
    neg_img,
    pos_img,
    MAX_ATTEMPTS,
    documentation,
    total_input_token,
    total_output_token,
):
    """
    Process one section by running every one of its leaf sections in order.

    The leaves are the keys of ``content[section_name]['subsections']`` when
    that key exists, otherwise the section itself is the only leaf. After a
    leaf produces edits, the final log entry of that leaf is appended to
    ``style_logs[section_name]`` right away, because ``process_leaf_section``
    starts from ``style_logs[section_name][-1]['code']``: recording it
    immediately makes the next leaf build on the previous leaf's result
    instead of on the stale pre-edit code.

    Args:
        section_name: Key of the section to process.
        content: Content dict; decides which leaves exist.
        outline: Outline dict, passed through to the leaf processor.
        sections: Unused here; kept so the call signature stays unchanged.
        style_logs: Per-section log lists; ``style_logs[section_name]`` is
            appended to in place.
        critic_logs, actor_logs, img_logs: Log dicts threaded through the
            leaf processor.
        total_input_token, total_output_token: Running token counters.
        (All remaining arguments are forwarded to ``process_leaf_section``
        unchanged.)

    Returns:
        dict with keys ``section_name``, ``style_logs``, ``critic_logs``,
        ``actor_logs``, ``img_logs``, ``total_input_token`` and
        ``total_output_token``.
    """
    # Determine which leaf sections to process
    if 'subsections' in content[section_name]:
        subsections = list(content[section_name]['subsections'].keys())
    else:
        subsections = [section_name]

    for leaf_section in subsections:
        leaf_result = process_leaf_section(
            leaf_section,
            section_name,
            outline,
            content,
            style_logs,
            critic_logs,
            actor_logs,
            img_logs,
            slide_width,
            slide_height,
            name_to_hierarchy,
            critic_template,
            actor_template,
            critic_agent,
            actor_agent,
            neg_img,
            pos_img,
            MAX_ATTEMPTS,
            documentation,
            total_input_token,
            total_output_token,
        )

        # Carry the logs and running token totals into the next leaf.
        img_logs = leaf_result["img_logs"]
        critic_logs = leaf_result["critic_logs"]
        actor_logs = leaf_result["actor_logs"]
        total_input_token = leaf_result["total_input_token"]
        total_output_token = leaf_result["total_output_token"]

        # Publish this leaf's final edit so the next leaf starts from it.
        # A leaf with no edits (empty log) leaves the section code untouched.
        if leaf_result["log"]:
            style_logs[section_name].append(leaf_result["log"][-1])

    # Return updated state for merging back in the main thread
    return {
        "section_name": section_name,
        "style_logs": style_logs,
        "critic_logs": critic_logs,
        "actor_logs": actor_logs,
        "img_logs": img_logs,
        "total_input_token": total_input_token,
        "total_output_token": total_output_token
    }
def parallel_by_sections(
    sections,
    content,
    outline,
    style_logs,
    critic_logs,
    actor_logs,
    img_logs,
    slide_width,
    slide_height,
    name_to_hierarchy,
    critic_template,
    actor_template,
    critic_agent,
    actor_agent,
    neg_img,
    pos_img,
    MAX_ATTEMPTS,
    documentation,
    total_input_token,
    total_output_token,
    max_workers=4
):
    """
    Process all sections concurrently and merge the results.

    One ``process_section`` task is submitted per section to a thread pool.
    Each task receives its own deep copies of the four log dictionaries so
    that workers never mutate the same dict; the results are merged back into
    the caller's dictionaries once every task has finished.

    Args:
        sections: Iterable of section names to process.
        style_logs, critic_logs, actor_logs, img_logs: Log dicts; updated in
            place during the merge and also returned.
        total_input_token, total_output_token: Token totals before this call.
        max_workers: Size of the thread pool.
        (All remaining arguments are forwarded to ``process_section``.)

    Returns:
        Tuple ``(style_logs, critic_logs, actor_logs, img_logs,
        total_input_token, total_output_token)``, where the token totals are
        the incoming totals plus the usage of every section.

    Raises:
        Any exception raised inside a worker is re-raised by
        ``future.result()``.
    """
    # Every worker starts counting from these same totals, so what a worker
    # actually consumed is (its returned total - base).
    base_input_token = total_input_token
    base_output_token = total_output_token

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for section_name in sections:
            # Private copies per task avoid concurrent updates to shared dicts.
            _style_logs = copy.deepcopy(style_logs)
            _critic_logs = copy.deepcopy(critic_logs)
            _actor_logs = copy.deepcopy(actor_logs)
            _img_logs = copy.deepcopy(img_logs)
            # NOTE(review): critic_agent and actor_agent are the same objects
            # in every task, and the workers call reset()/step() on them.
            # Confirm the agents tolerate concurrent use, or give each task
            # its own agents.
            futures.append(executor.submit(
                process_section,
                section_name,
                content,
                outline,
                sections,
                _style_logs,
                _critic_logs,
                _actor_logs,
                _img_logs,
                slide_width,
                slide_height,
                name_to_hierarchy,
                critic_template,
                actor_template,
                critic_agent,
                actor_agent,
                neg_img,
                pos_img,
                MAX_ATTEMPTS,
                documentation,
                total_input_token,
                total_output_token
            ))
        # Collect in submission order; result() blocks until the task is done.
        for future in futures:
            results.append(future.result())

    for res in results:
        sec_name = res["section_name"]
        # A worker only changes the style log of its own section.
        style_logs[sec_name] = res["style_logs"][sec_name]
        critic_logs.update(res["critic_logs"])
        actor_logs.update(res["actor_logs"])
        img_logs.update(res["img_logs"])
        # Accumulate each section's usage; assigning the worker's total
        # directly would keep only the last section's tokens.
        total_input_token += res["total_input_token"] - base_input_token
        total_output_token += res["total_output_token"] - base_output_token

    return style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token
def deoverflow(args, actor_config, critic_config):
    """
    Run the critic/actor refinement over every poster section and save the result.

    Reads the style checkpoint and the main checkpoint from ``checkpoints/``
    and the poster content JSON from ``contents/``, builds the critic and
    actor agents from their YAML prompt templates, refines all sections via
    ``parallel_by_sections``, executes the concatenated final code, and then
    writes preview images, a copy of ``poster.pptx`` under ``results/`` and a
    ``non_overlap`` checkpoint pickle.

    Args:
        args: Namespace providing ``model_name``, ``poster_name`` and
            ``index``; these are used to build all input/output file names.
        actor_config: Dict with ``model_platform``, ``model_type`` and
            ``model_config`` for the actor model.
        critic_config: Same shape as ``actor_config``, for the critic model.

    Returns:
        Tuple ``(total_input_token, total_output_token)``.
    """
    # Local import: this file has no explicit top-level ``import os``.
    import os

    total_input_token, total_output_token = 0, 0

    # NOTE(review): pickle can execute arbitrary code while loading; these
    # checkpoints must only ever come from trusted pipeline runs.
    style_ckpt_path = f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl'
    with open(style_ckpt_path, 'rb') as f:
        style_ckpt = pkl.load(f)
    logs_ckpt_path = f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl'
    with open(logs_ckpt_path, 'rb') as f:
        logs_ckpt = pkl.load(f)

    style_logs = style_ckpt['style_logs']
    # Every key of style_logs except 'meta' is treated as a section.
    sections = [s for s in style_logs.keys() if s != 'meta']

    slide_width = style_ckpt['outline']['meta']['width']
    slide_height = style_ckpt['outline']['meta']['height']

    content_path = f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json'
    with open(content_path, 'r') as f:
        content = json.load(f)

    outline = logs_ckpt['outline']
    name_to_hierarchy = get_hierarchy(outline, 1)

    critic_agent_name = 'critic_overlap_agent'
    with open(f"prompt_templates/{critic_agent_name}.yaml", "r") as f:
        deoverflow_critic_config = yaml.safe_load(f)

    actor_agent_name = 'actor_editor_agent'
    with open(f"prompt_templates/{actor_agent_name}.yaml", "r") as f:
        deoverflow_actor_config = yaml.safe_load(f)

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )
    actor_sys_msg = deoverflow_actor_config['system_prompt']
    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=actor_model,
        message_window_size=10,
    )

    critic_model = ModelFactory.create(
        model_platform=critic_config['model_platform'],
        model_type=critic_config['model_type'],
        model_config_dict=critic_config['model_config'],
    )
    critic_sys_msg = deoverflow_critic_config['system_prompt']
    critic_agent = ChatAgent(
        system_message=critic_sys_msg,
        model=critic_model,
        message_window_size=None,
    )

    # StrictUndefined makes a template fail loudly on a missing variable.
    jinja_env = Environment(undefined=StrictUndefined)
    actor_template = jinja_env.from_string(deoverflow_actor_config["template"])
    critic_template = jinja_env.from_string(deoverflow_critic_config["template"])

    critic_logs = {}
    actor_logs = {}
    img_logs = {}

    # Load neg and pos examples
    neg_img = Image.open('overflow_example/neg.jpg')
    pos_img = Image.open('overflow_example/pos.jpg')

    style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token = parallel_by_sections(
        sections=sections,
        content=content,
        outline=outline,
        style_logs=style_logs,
        critic_logs=critic_logs,
        actor_logs=actor_logs,
        img_logs=img_logs,
        slide_width=slide_width,
        slide_height=slide_height,
        name_to_hierarchy=name_to_hierarchy,
        critic_template=critic_template,
        actor_template=actor_template,
        critic_agent=critic_agent,
        actor_agent=actor_agent,
        neg_img=neg_img,
        pos_img=pos_img,
        MAX_ATTEMPTS=MAX_ATTEMPTS,
        documentation=documentation,
        total_input_token=total_input_token,
        total_output_token=total_output_token,
        max_workers=100,  # or however many worker threads you want
    )

    # The latest code of every section, one section per line block.
    final_code_by_section = {
        section: style_logs[section][-1]['code'] for section in sections
    }
    final_code = ''.join(
        final_code_by_section[section] + '\n' for section in sections
    )

    # Executing the final code is what produces 'poster.pptx' used below
    # (presumably written by the generated code itself - TODO confirm).
    run_code_with_utils(final_code, utils_functions)
    ppt_to_images('poster.pptx', 'tmp/non_overlap_preview')

    result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
    os.makedirs(result_dir, exist_ok=True)
    shutil.copy('poster.pptx', f'{result_dir}/non_overlap_poster.pptx')
    ppt_to_images('poster.pptx', f'{result_dir}/non_overlap_poster_preview')

    non_overlap_ckpt = {
        'critic_logs': critic_logs,
        'actor_logs': actor_logs,
        'img_logs': img_logs,
        'name_to_hierarchy': name_to_hierarchy,
        'final_code': final_code,
        'final_code_by_section': final_code_by_section,
        'total_input_token': total_input_token,
        'total_output_token': total_output_token
    }

    # The context manager guarantees the checkpoint is flushed and closed.
    out_ckpt_path = f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl'
    with open(out_ckpt_path, 'wb') as f:
        pkl.dump(non_overlap_ckpt, f)

    return total_input_token, total_output_token
if __name__ == '__main__':
    # Command-line entry point: parse options, resolve the poster name and
    # run the deoverflow pass, then report token usage.
    cli = argparse.ArgumentParser()
    cli.add_argument('--poster_name', type=str, default=None)
    cli.add_argument('--model_name', type=str, default='4o')
    cli.add_argument('--poster_path', type=str, required=True)
    cli.add_argument('--index', type=int, default=0)
    cli.add_argument('--max_retry', type=int, default=3)
    args = cli.parse_args()

    # Actor and critic are both configured from the same model name.
    actor_config = get_agent_config(args.model_name)
    critic_config = get_agent_config(args.model_name)

    # Without an explicit name, derive it from the last path component:
    # drop '.pdf' and turn spaces into underscores.
    if args.poster_name is None:
        file_name = args.poster_path.split('/')[-1]
        args.poster_name = file_name.replace('.pdf', '').replace(' ', '_')

    input_token, output_token = deoverflow(args, actor_config, critic_config)

    print(f'Token consumption: {input_token} -> {output_token}')