Spaces:
Runtime error
Runtime error
| import asyncio | |
| from dataclasses import dataclass | |
| from typing import Literal, Optional | |
| from jinja2 import StrictUndefined, Template | |
| from pptagent.llms import LLM, AsyncLLM | |
| from pptagent.utils import get_logger, package_join, pbasename, pexists, pjoin | |
logger = get_logger(__name__)

# Jinja2 template used to ask the language model to shorten over-long text.
# StrictUndefined makes rendering fail loudly when a template variable is
# missing instead of silently rendering an empty string.
# The file is read inside a `with` block so the handle is closed right away,
# and the encoding is pinned so loading does not depend on the platform locale
# (assumes the prompt file is UTF-8 -- TODO confirm).
with open(
    package_join("prompts", "lengthy_rewrite.txt"), encoding="utf-8"
) as _prompt_file:
    LENGTHY_REWRITE_PROMPT = Template(
        _prompt_file.read(),
        undefined=StrictUndefined,
    )
| class Element: | |
| el_name: str | |
| content: list[str] | |
| description: str | |
| el_type: Literal["text", "image"] | |
| suggested_characters: int | None | |
| variable_length: tuple[int, int] | None | |
| variable_data: dict[str, list[str]] | None | |
| def get_schema(self): | |
| schema = f"Element: {self.el_name}\n" | |
| base_attrs = ["description", "el_type"] | |
| for attr in base_attrs: | |
| schema += f"\t{attr}: {getattr(self, attr)}\n" | |
| if self.el_type == "text": | |
| schema += f"\tsuggested_characters: {self.suggested_characters}\n" | |
| if self.variable_length is not None: | |
| schema += f"\tThe length of the element can vary between {self.variable_length[0]} and {self.variable_length[1]}\n" | |
| schema += f"\tThe default quantity of the element is {len(self.content)}\n" | |
| return schema | |
| def from_dict(cls, el_name: str, data: dict): | |
| if not isinstance(data["data"], list): | |
| data["data"] = [data["data"]] | |
| if data["type"] == "text": | |
| suggested_characters = max(len(i) for i in data["data"]) | |
| elif data["type"] == "image": | |
| suggested_characters = None | |
| return cls( | |
| el_name=el_name, | |
| el_type=data["type"], | |
| content=data["data"], | |
| description=data["description"], | |
| variable_length=data.get("variableLength", None), | |
| variable_data=data.get("variableData", None), | |
| suggested_characters=suggested_characters, | |
| ) | |
@dataclass
class Layout:
    """A slide layout: a titled set of elements backed by template slides.

    Attributes:
        title: Name of the layout.
        template_id: Slide id returned by ``get_slide_id`` when no element has
            a variable length.
        slides: Integer slide ids; ``int in layout`` tests membership here.
        elements: Elements of the layout; at most one may be variable-length
            when built through ``from_dict``.
        vary_mapping: Maps the item count of the variable element (rendered as
            a string) to the slide id to use, or ``None``.
    """

    title: str
    template_id: int
    slides: list[int]
    elements: list[Element]
    # Keys are strings because get_slide_id looks entries up with str(count).
    vary_mapping: dict[str, int] | None  # mapping for variable elements

    @classmethod
    def from_dict(cls, title: str, data: dict) -> "Layout":
        """Build a layout from its dictionary description.

        Args:
            title: Name of the layout.
            data: Mapping with the keys ``"content_schema"`` (element name to
                element description), ``"template_id"`` and ``"slides"``, and
                optionally ``"vary_mapping"``.

        Raises:
            ValueError: If more than one element declares a variable length.
        """
        elements = [
            Element.from_dict(el_name, el_data)
            for el_name, el_data in data["content_schema"].items()
        ]
        num_vary_elements = sum(el.variable_length is not None for el in elements)
        if num_vary_elements > 1:
            raise ValueError("Only one variable element is allowed")
        return cls(
            title=title,
            template_id=data["template_id"],
            slides=data["slides"],
            elements=elements,
            vary_mapping=data.get("vary_mapping"),
        )

    def get_slide_id(self, data: dict):
        """Return the slide id matching the item count of the variable element.

        Args:
            data: Editor output mapping element names to ``{"data": [...]}``.

        Returns:
            The ``vary_mapping`` entry for the number of items supplied for
            the first variable-length element, or ``template_id`` when the
            layout has no variable-length element.

        Raises:
            ValueError: If the item count is outside the element's
                ``variable_length`` bounds.
        """
        for el in self.elements:
            if el.variable_length is None:
                continue
            num_vary = len(data[el.el_name]["data"])
            minimum, maximum = el.variable_length[0], el.variable_length[1]
            if num_vary < minimum:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is less than the minimum length: {el.variable_length[0]}"
                )
            if num_vary > maximum:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is greater than the maximum length: {el.variable_length[1]}"
                )
            return self.vary_mapping[str(num_vary)]
        return self.template_id

    def get_old_data(self, editor_output: Optional[dict] = None):
        """Return the default content of every element, keyed by element name.

        Without ``editor_output`` each element maps to its ``content``. With
        it, a variable-length element maps to the ``variable_data`` entry for
        the number of items the editor supplied.

        Raises:
            AssertionError: If the supplied item count has no entry in the
                element's ``variable_data``.
        """
        if editor_output is None:
            return {el.el_name: el.content for el in self.elements}

        old_data = {}
        for el in self.elements:
            if el.variable_length is None:
                old_data[el.el_name] = el.content
                continue
            key = str(len(editor_output[el.el_name]["data"]))
            # NOTE(review): kept as an assert so the raised exception type is
            # unchanged for callers; asserts are skipped under `python -O`.
            assert (
                key in el.variable_data
            ), f"The length of element {el.el_name} varies between {el.variable_length[0]} and {el.variable_length[1]}, but got data of length {key} which is not supported"
            old_data[el.el_name] = el.variable_data[key]
        return old_data

    def validate(self, editor_output: dict, image_dir: str):
        """Check the editor output and resolve image paths in place.

        Every entry must contain a ``"data"`` key and name an element of this
        layout. For image elements each path is replaced by the first existing
        candidate among: ``image_dir`` joined with the path, the path itself,
        and ``image_dir`` joined with the path's base name.

        Raises:
            AssertionError: If ``"data"`` is missing or the element is unknown.
            ValueError: If an image cannot be found at any candidate location.
        """
        for el_name, el_data in editor_output.items():
            # NOTE(review): the original indentation inside this message was
            # not recoverable from the source; lines are kept unindented.
            assert "data" in el_data, (
                "key `data` not found in output\n"
                "please give your output as a dict like\n"
                "{\n"
                '"element1": {\n'
                '"data": ["text1", "text2"] for text elements\n'
                'or ["/path/to/image", "..."] for image elements\n'
                "},\n"
                "}"
            )
            assert (
                el_name in self
            ), f"Element {el_name} is not a valid element, supported elements are {[el.el_name for el in self.elements]}"

            if self[el_name].el_type != "image":
                continue

            paths = el_data["data"]
            for idx, path in enumerate(paths):
                in_image_dir = pjoin(image_dir, path)
                if pexists(in_image_dir):
                    paths[idx] = path = in_image_dir
                if pexists(path):
                    continue
                # Last resort: look for a file with the same name in image_dir.
                by_basename = pjoin(image_dir, pbasename(path))
                if not pexists(by_basename):
                    raise ValueError(
                        f"Image {path} not found\n"
                        "Please check the image path and use only existing images\n"
                        "Or, leave a blank list for this element"
                    )
                paths[idx] = by_basename

    def _needs_rewrite(self, el_name: str, el_data: dict, length_factor: float) -> bool:
        """Tell whether a text element's longest item exceeds its length budget."""
        element = self[el_name]
        if element.el_type != "text":
            return False
        # default=0 lets an empty data list pass instead of crashing max().
        longest = max((len(item) for item in el_data["data"]), default=0)
        return longest > element.suggested_characters * length_factor

    def _render_rewrite_prompt(self, el_name: str, el_data: dict) -> str:
        """Render the lengthy-rewrite prompt for one element."""
        return LENGTHY_REWRITE_PROMPT.render(
            el_name=el_name,
            content=el_data["data"],
            suggested_characters=f"{self[el_name].suggested_characters} characters",
        )

    @staticmethod
    def _check_rewrite(el_name: str, new_data) -> None:
        """Assert that the language model answered the rewrite with a list."""
        # The message reports the type rather than len(new_data): calling len()
        # on a non-list may itself fail and hide the real problem.
        assert isinstance(
            new_data, list
        ), f"Lengthy rewrite for element {el_name} should return a list, but got {type(new_data).__name__}"

    def validate_length(
        self, editor_output: dict, length_factor: float, language_model: LLM
    ):
        """Rewrite over-long text elements in place using ``language_model``.

        A text element is rewritten when its longest item has more characters
        than ``suggested_characters * length_factor``.

        Args:
            editor_output: Mapping of element names to ``{"data": [...]}``;
                rewritten data replaces ``"data"``.
            length_factor: Multiplier applied to the suggested length.
            language_model: Callable invoked as
                ``language_model(prompt, return_json=True)``.

        Raises:
            AssertionError: If the model does not return a list; the element's
                data is left untouched in that case.
        """
        for el_name, el_data in editor_output.items():
            if not self._needs_rewrite(el_name, el_data, length_factor):
                continue
            rewritten = language_model(
                self._render_rewrite_prompt(el_name, el_data),
                return_json=True,
            )
            self._check_rewrite(el_name, rewritten)
            el_data["data"] = rewritten

    async def validate_length_async(
        self, editor_output: dict, length_factor: float, language_model: AsyncLLM
    ):
        """Asynchronous counterpart of ``validate_length``.

        All rewrite requests are started as tasks of one ``asyncio.TaskGroup``
        and their results are then written back to ``editor_output``.

        Raises:
            AssertionError: If the model does not return a list for an element
                (raised inside the task group, so it may surface wrapped in an
                exception group).
        """
        async with asyncio.TaskGroup() as tg:
            tasks = {}
            for el_name, el_data in editor_output.items():
                if not self._needs_rewrite(el_name, el_data, length_factor):
                    continue
                tasks[el_name] = tg.create_task(
                    language_model(
                        self._render_rewrite_prompt(el_name, el_data),
                        return_json=True,
                    )
                )

            for el_name, task in tasks.items():
                new_data = await task
                # Validate the rewritten data, as the synchronous version does;
                # the previous code checked the old data before awaiting.
                self._check_rewrite(el_name, new_data)
                logger.debug(
                    f"Lengthy rewrite for {el_name}:\n From {editor_output[el_name]['data']}\n To {new_data}"
                )
                editor_output[el_name]["data"] = new_data

    # NOTE(review): decorator lines appear to be missing from this source;
    # confirm against callers whether this was originally a @property.
    def content_schema(self):
        """Return the schema text of all elements, joined by newlines."""
        return "\n".join(el.get_schema() for el in self.elements)

    def remove_item(self, item: str):
        """Remove ``item`` from the first element whose content contains it.

        An element left without content is removed from the layout.

        Raises:
            ValueError: If no element contains ``item``.
        """
        for idx, el in enumerate(self.elements):
            if item not in el.content:
                continue
            el.content.remove(item)
            if not el.content:
                # Safe while iterating because we return immediately after.
                del self.elements[idx]
            return
        raise ValueError(f"Item {item} not found in layout {self.title}")

    def __contains__(self, key: str | int):
        """Test for a slide id (``int``) or an element name (``str``)."""
        if isinstance(key, int):
            return key in self.slides
        if isinstance(key, str):
            return any(el.el_name == key for el in self.elements)
        raise ValueError(f"Invalid key type: {type(key)}, should be str or int")

    def __getitem__(self, key: str):
        """Return the element named ``key``; raise ``ValueError`` if absent."""
        for el in self.elements:
            if el.el_name == key:
                return el
        raise ValueError(f"Element {key} not found")

    def __iter__(self):
        """Iterate over the layout's elements."""
        return iter(self.elements)

    def __len__(self):
        """Return the number of elements."""
        return len(self.elements)