sjw712's picture
Upload 208 files
d961e88 verified
import asyncio
from dataclasses import dataclass
from typing import Literal, Optional
from jinja2 import StrictUndefined, Template
from pptagent.llms import LLM, AsyncLLM
from pptagent.utils import get_logger, package_join, pbasename, pexists, pjoin
logger = get_logger(__name__)

# Jinja2 template used to ask the language model to shorten over-long text.
# It is loaded once at import time; the file is read inside a ``with`` block
# so the handle is closed deterministically instead of being left to the
# garbage collector. ``StrictUndefined`` makes rendering fail on a missing
# template variable instead of silently emitting an empty string.
with open(package_join("prompts", "lengthy_rewrite.txt")) as _prompt_file:
    LENGTHY_REWRITE_PROMPT = Template(
        _prompt_file.read(),
        undefined=StrictUndefined,
    )
@dataclass
class Element:
el_name: str
content: list[str]
description: str
el_type: Literal["text", "image"]
suggested_characters: int | None
variable_length: tuple[int, int] | None
variable_data: dict[str, list[str]] | None
def get_schema(self):
schema = f"Element: {self.el_name}\n"
base_attrs = ["description", "el_type"]
for attr in base_attrs:
schema += f"\t{attr}: {getattr(self, attr)}\n"
if self.el_type == "text":
schema += f"\tsuggested_characters: {self.suggested_characters}\n"
if self.variable_length is not None:
schema += f"\tThe length of the element can vary between {self.variable_length[0]} and {self.variable_length[1]}\n"
schema += f"\tThe default quantity of the element is {len(self.content)}\n"
return schema
@classmethod
def from_dict(cls, el_name: str, data: dict):
if not isinstance(data["data"], list):
data["data"] = [data["data"]]
if data["type"] == "text":
suggested_characters = max(len(i) for i in data["data"])
elif data["type"] == "image":
suggested_characters = None
return cls(
el_name=el_name,
el_type=data["type"],
content=data["data"],
description=data["description"],
variable_length=data.get("variableLength", None),
variable_data=data.get("variableData", None),
suggested_characters=suggested_characters,
)
@dataclass
class Layout:
    """A slide layout: a set of template slides sharing one content schema.

    Attributes:
        title: Name of the layout (used in error messages).
        template_id: Slide id returned by `get_slide_id` when the layout has
            no variable-length element.
        slides: Slide ids belonging to the layout; ``int in layout`` tests
            membership in this list.
        elements: The layout's elements; at most one may be variable-length
            when built through `from_dict`.
        vary_mapping: Maps ``str(item_count)`` of the variable-length element
            to a slide id. Keys are strings because `get_slide_id` looks
            them up with ``str(...)``.
    """

    title: str
    template_id: int
    slides: list[int]
    elements: list[Element]
    vary_mapping: dict[str, int] | None  # mapping for variable elements

    @classmethod
    def from_dict(cls, title: str, data: dict):
        """Build a layout from its schema dict.

        Args:
            title: Name of the layout.
            data: Dict with the required keys ``"content_schema"`` (element
                name -> element dict), ``"template_id"`` and ``"slides"``,
                and the optional key ``"vary_mapping"``.

        Raises:
            ValueError: If more than one element is variable-length.
        """
        elements = [
            Element.from_dict(el_name, el_schema)
            for el_name, el_schema in data["content_schema"].items()
        ]
        num_vary_elements = sum(el.variable_length is not None for el in elements)
        if num_vary_elements > 1:
            raise ValueError("Only one variable element is allowed")
        return cls(
            title=title,
            template_id=data["template_id"],
            slides=data["slides"],
            elements=elements,
            vary_mapping=data.get("vary_mapping"),
        )

    def get_slide_id(self, data: dict):
        """Return the slide id to use for the given editor output.

        For the first variable-length element, the number of items supplied
        in ``data[el_name]["data"]`` is checked against the allowed range and
        looked up in `vary_mapping`. Without such an element, `template_id`
        is returned.

        Raises:
            ValueError: If the item count is outside the allowed range.
        """
        for el in self.elements:
            if el.variable_length is None:
                continue
            minimum, maximum = el.variable_length[0], el.variable_length[1]
            num_vary = len(data[el.el_name]["data"])
            if num_vary < minimum:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is less than the minimum length: {minimum}"
                )
            if num_vary > maximum:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is greater than the maximum length: {maximum}"
                )
            return self.vary_mapping[str(num_vary)]
        return self.template_id

    def get_old_data(self, editor_output: Optional[dict] = None):
        """Return the layout's original content keyed by element name.

        Without ``editor_output`` every element maps to its default content.
        With it, a variable-length element maps to the `variable_data` entry
        whose key is the number of items supplied for that element.
        """
        if editor_output is None:
            return {el.el_name: el.content for el in self.elements}
        old_data = {}
        for el in self.elements:
            if el.variable_length is None:
                old_data[el.el_name] = el.content
                continue
            key = str(len(editor_output[el.el_name]["data"]))
            assert (
                key in el.variable_data
            ), f"The length of element {el.el_name} varies between {el.variable_length[0]} and {el.variable_length[1]}, but got data of length {key} which is not supported"
            old_data[el.el_name] = el.variable_data[key]
        return old_data

    def validate(self, editor_output: dict, image_dir: str):
        """Check the editor output and resolve image paths in place.

        Every entry must contain a ``"data"`` key and name an element of this
        layout. For image elements each path is replaced by
        ``image_dir/<path>`` if that exists, kept if it exists as given, or
        replaced by ``image_dir/<basename>`` if that exists.

        Raises:
            AssertionError: If an entry lacks ``"data"`` or names an unknown
                element.
            ValueError: If an image cannot be found under any of the forms
                above.
        """
        for el_name, el_data in editor_output.items():
            assert "data" in el_data, (
                "key `data` not found in output\n"
                "please give your output as a dict like\n"
                "{\n"
                '"element1": {\n'
                '"data": ["text1", "text2"] for text elements\n'
                'or ["/path/to/image", "..."] for image elements\n'
                "},\n"
                "}"
            )
            assert (
                el_name in self
            ), f"Element {el_name} is not a valid element, supported elements are {[el.el_name for el in self.elements]}"
            if self[el_name].el_type != "image":
                continue
            images = el_data["data"]
            for i, image in enumerate(images):
                in_image_dir = pjoin(image_dir, image)
                if pexists(in_image_dir):
                    images[i] = in_image_dir
                    continue
                if pexists(image):
                    continue
                # Last resort: look the bare file name up in ``image_dir``.
                by_basename = pjoin(image_dir, pbasename(image))
                if not pexists(by_basename):
                    raise ValueError(
                        f"Image {image} not found\n"
                        "Please check the image path and use only existing images\n"
                        "Or, leave a blank list for this element"
                    )
                images[i] = by_basename

    def _lengthy_rewrite_prompt(
        self, el_name: str, el_data: dict, length_factor: float
    ) -> str | None:
        """Return the rewrite prompt if a text element is too long, else None."""
        element = self[el_name]
        if element.el_type != "text":
            return None
        texts = el_data["data"]
        if not texts:
            # Nothing to measure; ``max()`` of an empty sequence would raise.
            return None
        longest = max(len(text) for text in texts)
        if longest > element.suggested_characters * length_factor:
            return LENGTHY_REWRITE_PROMPT.render(
                el_name=el_name,
                content=texts,
                suggested_characters=f"{element.suggested_characters} characters",
            )
        return None

    def _apply_rewrite(self, editor_output: dict, el_name: str, new_data):
        """Type-check a rewrite result and store it as the element's data."""
        # Check the model's answer itself, and build the message without
        # calling ``len()`` on a value that is known not to be a list.
        assert isinstance(
            new_data, list
        ), f"Lengthy rewrite of element {el_name} must return a list, but got {type(new_data).__name__}"
        logger.debug(
            f"Lengthy rewrite for {el_name}:\n From {editor_output[el_name]['data']}\n To {new_data}"
        )
        editor_output[el_name]["data"] = new_data

    def validate_length(
        self, editor_output: dict, length_factor: float, language_model: LLM
    ):
        """Rewrite over-long text elements of ``editor_output`` in place.

        A text element is rewritten when its longest item exceeds
        ``suggested_characters * length_factor``. Elements with empty data
        are left untouched.

        Args:
            editor_output: Mapping of element name to ``{"data": [...]}``.
            length_factor: Multiplier applied to the suggested length.
            language_model: Called as ``language_model(prompt,
                return_json=True)``; its result replaces the element data.

        Raises:
            AssertionError: If the model does not return a list.
        """
        for el_name, el_data in editor_output.items():
            prompt = self._lengthy_rewrite_prompt(el_name, el_data, length_factor)
            if prompt is None:
                continue
            self._apply_rewrite(
                editor_output, el_name, language_model(prompt, return_json=True)
            )

    async def validate_length_async(
        self, editor_output: dict, length_factor: float, language_model: AsyncLLM
    ):
        """Asynchronous variant of `validate_length`.

        All required rewrites are issued concurrently in one task group and
        the results are written back once every task has finished.

        Raises:
            AssertionError: If the model does not return a list.
        """
        prompts = {}
        for el_name, el_data in editor_output.items():
            prompt = self._lengthy_rewrite_prompt(el_name, el_data, length_factor)
            if prompt is not None:
                prompts[el_name] = prompt
        if not prompts:
            return

        async with asyncio.TaskGroup() as tg:
            tasks = {
                el_name: tg.create_task(language_model(prompt, return_json=True))
                for el_name, prompt in prompts.items()
            }
        # Leaving the task group guarantees that every task has completed.
        for el_name, task in tasks.items():
            self._apply_rewrite(editor_output, el_name, task.result())

    @property
    def content_schema(self):
        """Schema text of all elements, joined by newlines."""
        return "\n".join(el.get_schema() for el in self.elements)

    def remove_item(self, item: str):
        """Remove ``item`` from the first element whose content contains it.

        An element left without content is dropped from the layout.

        Raises:
            ValueError: If no element contains ``item``.
        """
        for el in self.elements:
            if item not in el.content:
                continue
            el.content.remove(item)
            if not el.content:
                # Safe despite the enclosing loop: we return immediately.
                self.elements.remove(el)
            return
        raise ValueError(f"Item {item} not found in layout {self.title}")

    def __contains__(self, key: str | int):
        """Test for a slide id (``int``) or an element name (``str``)."""
        if isinstance(key, int):
            return key in self.slides
        if isinstance(key, str):
            return any(el.el_name == key for el in self.elements)
        raise ValueError(f"Invalid key type: {type(key)}, should be str or int")

    def __getitem__(self, key: str):
        """Return the element named ``key``; raise ValueError if absent."""
        for el in self.elements:
            if el.el_name == key:
                return el
        raise ValueError(f"Element {key} not found")

    def __iter__(self):
        """Iterate over the layout's elements."""
        return iter(self.elements)

    def __len__(self):
        """Number of elements in the layout."""
        return len(self.elements)