Spaces:
Runtime error
Runtime error
File size: 9,980 Bytes
d961e88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import asyncio
from dataclasses import dataclass
from typing import Literal, Optional
from jinja2 import StrictUndefined, Template
from pptagent.llms import LLM, AsyncLLM
from pptagent.utils import get_logger, package_join, pbasename, pexists, pjoin
logger = get_logger(__name__)

# Prompt used by Layout.validate_length / validate_length_async to ask the
# language model to shorten over-long text elements. StrictUndefined makes a
# missing template variable raise at render time instead of rendering empty.
# The file is read inside a `with` block so the handle is closed right away.
with open(
    package_join("prompts", "lengthy_rewrite.txt"), encoding="utf-8"
) as _prompt_file:
    LENGTHY_REWRITE_PROMPT = Template(
        _prompt_file.read(),
        undefined=StrictUndefined,
    )
@dataclass
class Element:
el_name: str
content: list[str]
description: str
el_type: Literal["text", "image"]
suggested_characters: int | None
variable_length: tuple[int, int] | None
variable_data: dict[str, list[str]] | None
def get_schema(self):
schema = f"Element: {self.el_name}\n"
base_attrs = ["description", "el_type"]
for attr in base_attrs:
schema += f"\t{attr}: {getattr(self, attr)}\n"
if self.el_type == "text":
schema += f"\tsuggested_characters: {self.suggested_characters}\n"
if self.variable_length is not None:
schema += f"\tThe length of the element can vary between {self.variable_length[0]} and {self.variable_length[1]}\n"
schema += f"\tThe default quantity of the element is {len(self.content)}\n"
return schema
@classmethod
def from_dict(cls, el_name: str, data: dict):
if not isinstance(data["data"], list):
data["data"] = [data["data"]]
if data["type"] == "text":
suggested_characters = max(len(i) for i in data["data"])
elif data["type"] == "image":
suggested_characters = None
return cls(
el_name=el_name,
el_type=data["type"],
content=data["data"],
description=data["description"],
variable_length=data.get("variableLength", None),
variable_data=data.get("variableData", None),
suggested_characters=suggested_characters,
)
@dataclass
class Layout:
    """A slide layout: its template slide ids plus the schema of its elements.

    Supports ``in`` (element name or slide id), ``layout[name]`` element
    lookup, iteration over elements and ``len()`` (number of elements).
    """

    title: str
    template_id: int
    slides: list[int]
    elements: list[Element]
    # Optional mapping for the single variable-length element: its item count,
    # looked up as a *string* key in get_slide_id, -> slide id for that count.
    vary_mapping: dict[str, int] | None

    @classmethod
    def from_dict(cls, title: str, data: dict) -> "Layout":
        """Build a Layout from its serialized dict.

        Args:
            title: Layout title.
            data: Mapping with ``template_id``, ``slides``, ``content_schema``
                (element name -> element dict, see ``Element.from_dict``) and
                an optional ``vary_mapping``.

        Raises:
            ValueError: if more than one element has a variable length.
        """
        schema = data["content_schema"]
        elements = [Element.from_dict(name, spec) for name, spec in schema.items()]
        num_vary_elements = sum(el.variable_length is not None for el in elements)
        if num_vary_elements > 1:
            raise ValueError("Only one variable element is allowed")
        return cls(
            title=title,
            template_id=data["template_id"],
            slides=data["slides"],
            elements=elements,
            vary_mapping=data.get("vary_mapping", None),
        )

    def get_slide_id(self, data: dict) -> int:
        """Return the template slide id to render ``data`` with.

        If the layout has a variable-length element, the number of items
        ``data`` gives for it selects the slide through ``vary_mapping``.
        Otherwise ``template_id`` is returned.

        Raises:
            ValueError: if that item count is outside the allowed range.
        """
        for el in self.elements:
            if el.variable_length is None:
                continue
            num_vary = len(data[el.el_name]["data"])
            if num_vary < el.variable_length[0]:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is less than the minimum length: {el.variable_length[0]}"
                )
            if num_vary > el.variable_length[1]:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is greater than the maximum length: {el.variable_length[1]}"
                )
            return self.vary_mapping[str(num_vary)]
        return self.template_id

    def get_old_data(self, editor_output: Optional[dict] = None) -> dict:
        """Return the template's default content for every element.

        Without ``editor_output`` each element maps to its ``content``. With
        it, the variable-length element instead maps to the ``variable_data``
        entry matching the number of items the editor produced.
        """
        if editor_output is None:
            return {el.el_name: el.content for el in self.elements}
        old_data = {}
        for el in self.elements:
            if el.variable_length is None:
                old_data[el.el_name] = el.content
                continue
            key = str(len(editor_output[el.el_name]["data"]))
            assert (
                key in el.variable_data
            ), f"The length of element {el.el_name} varies between {el.variable_length[0]} and {el.variable_length[1]}, but got data of length {key} which is not supported"
            old_data[el.el_name] = el.variable_data[key]
        return old_data

    def validate(self, editor_output: dict, image_dir: str):
        """Check editor output against this layout and resolve image paths.

        Every entry must contain a ``data`` key and name an element of this
        layout. For image elements each path is rewritten in place. The path
        joined onto ``image_dir`` is preferred when it exists. Otherwise the
        given path is kept if it exists. As a last resort, the basename
        joined onto ``image_dir`` is tried.

        Raises:
            AssertionError: if an entry lacks ``data`` or names an unknown
                element (the messages are meant as feedback for the editor).
            ValueError: if an image cannot be found.
        """
        for el_name, el_data in editor_output.items():
            assert "data" in el_data, (
                "key `data` not found in output\n"
                "please give your output as a dict like\n"
                "{\n"
                '"element1": {\n'
                '"data": ["text1", "text2"] for text elements\n'
                'or ["/path/to/image", "..."] for image elements\n'
                "},\n"
                "}"
            )
            assert (
                el_name in self
            ), f"Element {el_name} is not a valid element, supported elements are {[el.el_name for el in self.elements]}"
            if self[el_name].el_type != "image":
                continue
            paths = el_data["data"]
            for i, path in enumerate(paths):
                candidate = pjoin(image_dir, path)
                if pexists(candidate):
                    path = candidate
                if not pexists(path):
                    candidate = pjoin(image_dir, pbasename(path))
                    if not pexists(candidate):
                        raise ValueError(
                            f"Image {path} not found\n"
                            "Please check the image path and use only existing images\n"
                            "Or, leave a blank list for this element"
                        )
                    path = candidate
                paths[i] = path

    def _needs_rewrite(self, el_name: str, el_data: dict, length_factor: float) -> bool:
        """Return True if text element ``el_name`` has an item that is too long."""
        element = self[el_name]
        if element.el_type != "text":
            return False
        # default=0: an empty list of texts never needs a rewrite (a bare max()
        # would raise ValueError on it).
        longest = max((len(text) for text in el_data["data"]), default=0)
        return longest > element.suggested_characters * length_factor

    def _render_rewrite_prompt(self, el_name: str, el_data: dict) -> str:
        """Render the prompt asking the model to shorten ``el_name``'s texts."""
        return LENGTHY_REWRITE_PROMPT.render(
            el_name=el_name,
            content=el_data["data"],
            suggested_characters=f"{self[el_name].suggested_characters} characters",
        )

    def _assert_rewrite_is_list(self, el_name: str, new_data) -> None:
        """Fail with an AssertionError unless the model's rewrite is a list."""
        # The message must not call len() on new_data: a non-sized value
        # (e.g. None) would raise TypeError and mask the assertion.
        assert isinstance(
            new_data, list
        ), f"Generated data is lengthy, expect a list of texts of {self[el_name].suggested_characters} characters, but got {type(new_data).__name__} for element {el_name}"

    def validate_length(
        self, editor_output: dict, length_factor: float, language_model: LLM
    ):
        """Shorten over-long text elements of ``editor_output`` in place.

        A text element is rewritten by ``language_model`` when its longest
        item exceeds ``suggested_characters * length_factor``.

        Raises:
            AssertionError: if the model's rewrite is not a list. In that case
                the element's original data is left unchanged.
        """
        for el_name, el_data in editor_output.items():
            if not self._needs_rewrite(el_name, el_data, length_factor):
                continue
            new_data = language_model(
                self._render_rewrite_prompt(el_name, el_data),
                return_json=True,
            )
            self._assert_rewrite_is_list(el_name, new_data)
            logger.debug(
                f"Lengthy rewrite for {el_name}:\n From {el_data['data']}\n To {new_data}"
            )
            el_data["data"] = new_data

    async def validate_length_async(
        self, editor_output: dict, length_factor: float, language_model: AsyncLLM
    ):
        """Async version of ``validate_length``; rewrites run concurrently.

        Raises:
            AssertionError: if a rewrite is not a list.
            ExceptionGroup: if a rewrite task fails (asyncio.TaskGroup semantics).
        """
        async with asyncio.TaskGroup() as tg:
            tasks = {
                el_name: tg.create_task(
                    language_model(
                        self._render_rewrite_prompt(el_name, el_data),
                        return_json=True,
                    )
                )
                for el_name, el_data in editor_output.items()
                if self._needs_rewrite(el_name, el_data, length_factor)
            }
        # Exiting the TaskGroup means every task finished successfully. Validate
        # the *generated* data (the old code checked the pre-rewrite data).
        for el_name, task in tasks.items():
            new_data = task.result()
            self._assert_rewrite_is_list(el_name, new_data)
            logger.debug(
                f"Lengthy rewrite for {el_name}:\n From {editor_output[el_name]['data']}\n To {new_data}"
            )
            editor_output[el_name]["data"] = new_data

    @property
    def content_schema(self) -> str:
        """Schemas of all elements, one ``Element.get_schema`` block each."""
        return "\n".join(el.get_schema() for el in self.elements)

    def remove_item(self, item: str):
        """Remove ``item`` from the first element whose content contains it.

        The element is dropped from the layout if its content becomes empty.

        Raises:
            ValueError: if no element's content contains ``item``.
        """
        for el in self.elements:
            if item in el.content:
                el.content.remove(item)
                if not el.content:
                    self.elements.remove(el)
                return
        raise ValueError(f"Item {item} not found in layout {self.title}")

    def __contains__(self, key: str | int):
        """An int is checked against ``slides``; a str against element names."""
        if isinstance(key, int):
            return key in self.slides
        if isinstance(key, str):
            return any(el.el_name == key for el in self.elements)
        raise ValueError(f"Invalid key type: {type(key)}, should be str or int")

    def __getitem__(self, key: str):
        """Return the element named ``key``; raise ValueError if absent."""
        for el in self.elements:
            if el.el_name == key:
                return el
        raise ValueError(f"Element {key} not found")

    def __iter__(self):
        return iter(self.elements)

    def __len__(self):
        return len(self.elements)
|