File size: 9,980 Bytes
d961e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import asyncio
from dataclasses import dataclass
from typing import Literal, Optional

from jinja2 import StrictUndefined, Template

from pptagent.llms import LLM, AsyncLLM
from pptagent.utils import get_logger, package_join, pbasename, pexists, pjoin

# Module-level logger keyed by this module's import name.
logger = get_logger(__name__)

# Jinja2 template used to ask the language model to shorten over-long text
# elements (rendered with `el_name`, `content` and `suggested_characters`).
# StrictUndefined turns a missing template variable into an error at render
# time instead of silently rendering an empty string.
# The prompt file is read inside a `with` block so the handle is closed
# deterministically; UTF-8 is requested explicitly so the result does not
# depend on the platform's default encoding (assumes the prompt file is UTF-8).
with open(
    package_join("prompts", "lengthy_rewrite.txt"), encoding="utf-8"
) as _prompt_file:
    LENGTHY_REWRITE_PROMPT = Template(
        _prompt_file.read(),
        undefined=StrictUndefined,
    )


@dataclass
class Element:
    el_name: str
    content: list[str]
    description: str
    el_type: Literal["text", "image"]
    suggested_characters: int | None
    variable_length: tuple[int, int] | None
    variable_data: dict[str, list[str]] | None

    def get_schema(self):
        schema = f"Element: {self.el_name}\n"
        base_attrs = ["description", "el_type"]
        for attr in base_attrs:
            schema += f"\t{attr}: {getattr(self, attr)}\n"
        if self.el_type == "text":
            schema += f"\tsuggested_characters: {self.suggested_characters}\n"
        if self.variable_length is not None:
            schema += f"\tThe length of the element can vary between {self.variable_length[0]} and {self.variable_length[1]}\n"
        schema += f"\tThe default quantity of the element is {len(self.content)}\n"
        return schema

    @classmethod
    def from_dict(cls, el_name: str, data: dict):
        if not isinstance(data["data"], list):
            data["data"] = [data["data"]]
        if data["type"] == "text":
            suggested_characters = max(len(i) for i in data["data"])
        elif data["type"] == "image":
            suggested_characters = None
        return cls(
            el_name=el_name,
            el_type=data["type"],
            content=data["data"],
            description=data["description"],
            variable_length=data.get("variableLength", None),
            variable_data=data.get("variableData", None),
            suggested_characters=suggested_characters,
        )


@dataclass
class Layout:
    """A slide layout: a titled set of slides sharing one list of elements.

    Besides holding the schema, the class validates editor output against it
    (element names, image paths, text length) and behaves like a small
    container: ``name in layout`` / ``slide_idx in layout``, ``layout[name]``,
    iteration over elements and ``len(layout)``.
    """

    # Name of the layout.
    title: str
    # Slide id returned by get_slide_id when no element is variable-length.
    template_id: int
    # Slide indices belonging to this layout (checked by `int in layout`).
    slides: list[int]
    # Elements making up the layout's content schema.
    elements: list[Element]
    # Mapping for the variable element: item count -> slide id. It is indexed
    # with str(count) in get_slide_id, so the keys are strings.
    vary_mapping: dict[str, int] | None

    @classmethod
    def from_dict(cls, title: str, data: dict):
        """Build a Layout from its dictionary form.

        Args:
            title: Name of the layout.
            data: Mapping with the keys ``content_schema`` (element name ->
                element dictionary), ``template_id`` and ``slides``; the
                optional key ``vary_mapping`` is read if present.

        Raises:
            ValueError: If more than one element is variable-length.
        """
        elements = [
            Element.from_dict(el_name, el_spec)
            for el_name, el_spec in data["content_schema"].items()
        ]
        variable_count = sum(el.variable_length is not None for el in elements)
        if variable_count > 1:
            raise ValueError("Only one variable element is allowed")
        return cls(
            title=title,
            template_id=data["template_id"],
            slides=data["slides"],
            elements=elements,
            vary_mapping=data.get("vary_mapping", None),
        )

    def get_slide_id(self, data: dict):
        """Return the slide id to use for the given editor output.

        For a layout with a variable-length element the id is looked up in
        ``vary_mapping`` by the number of items supplied for that element;
        otherwise ``template_id`` is returned.

        Raises:
            ValueError: If the number of items is outside the element's
                allowed ``variable_length`` range.
        """
        for el in self.elements:
            if el.variable_length is None:
                continue
            num_vary = len(data[el.el_name]["data"])
            if num_vary < el.variable_length[0]:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is less than the minimum length: {el.variable_length[0]}"
                )
            if num_vary > el.variable_length[1]:
                raise ValueError(
                    f"The length of {el.el_name}: {num_vary} is greater than the maximum length: {el.variable_length[1]}"
                )
            # At most one variable element exists (enforced in from_dict), so
            # the first one found decides the slide.
            return self.vary_mapping[str(num_vary)]
        return self.template_id

    def get_old_data(self, editor_output: Optional[dict] = None):
        """Return the layout's default content keyed by element name.

        Without ``editor_output`` every element maps to its default content.
        With it, a variable-length element maps to the ``variable_data`` entry
        whose key is the number of items the editor produced for that element.
        """
        if editor_output is None:
            return {el.el_name: el.content for el in self.elements}
        old_data = {}
        for el in self.elements:
            if el.variable_length is None:
                old_data[el.el_name] = el.content
                continue
            key = str(len(editor_output[el.el_name]["data"]))
            assert (
                key in el.variable_data
            ), f"The length of element {el.el_name} varies between {el.variable_length[0]} and {el.variable_length[1]}, but got data of length {key} which is not supported"
            old_data[el.el_name] = el.variable_data[key]
        return old_data

    def validate(self, editor_output: dict, image_dir: str):
        """Check editor output against the layout and resolve image paths.

        Every entry must carry a ``data`` key and name an element of this
        layout. For image elements each path is rewritten in place to an
        existing file: first ``image_dir/path``, then the path as given, then
        ``image_dir/basename(path)``.

        Raises:
            AssertionError: If ``data`` is missing or the element is unknown.
            ValueError: If an image cannot be found in any of the locations.
        """
        for el_name, el_data in editor_output.items():
            assert (
                "data" in el_data
            ), """key `data` not found in output

                    please give your output as a dict like

                    {

                        "element1": {

                            "data": ["text1", "text2"] for text elements

                            or ["/path/to/image", "..."] for image elements

                        },

                    }"""
            assert (
                el_name in self
            ), f"Element {el_name} is not a valid element, supported elements are {[el.el_name for el in self.elements]}"
            if self[el_name].el_type != "image":
                continue
            paths = el_data["data"]
            for idx, path in enumerate(paths):
                in_image_dir = pjoin(image_dir, path)
                if pexists(in_image_dir):
                    paths[idx] = in_image_dir
                    continue
                if pexists(path):
                    continue
                # Last resort: the file name alone, looked up in image_dir.
                by_basename = pjoin(image_dir, pbasename(path))
                if not pexists(by_basename):
                    raise ValueError(
                        f"Image {path} not found\n"
                        "Please check the image path and use only existing images\n"
                        "Or, leave a blank list for this element"
                    )
                paths[idx] = by_basename

    def _overlong_text_elements(self, editor_output: dict, length_factor: float):
        """Yield (name, data, element) for text elements exceeding their budget.

        An element is over budget when its longest item has more characters
        than ``suggested_characters * length_factor``. An element with an
        empty ``data`` list is never over budget.
        """
        for el_name, el_data in editor_output.items():
            element = self[el_name]
            if element.el_type != "text":
                continue
            # default=0 keeps an empty list from crashing max().
            longest = max((len(text) for text in el_data["data"]), default=0)
            if longest > element.suggested_characters * length_factor:
                yield el_name, el_data, element

    @staticmethod
    def _rewrite_prompt(el_name: str, content, suggested_characters) -> str:
        """Render the prompt asking the model to shorten one element."""
        return LENGTHY_REWRITE_PROMPT.render(
            el_name=el_name,
            content=content,
            suggested_characters=f"{suggested_characters} characters",
        )

    @staticmethod
    def _check_rewritten(el_name: str, new_data) -> None:
        """Assert that the model's rewrite of ``el_name`` is a list."""
        # The message reports the offending type; calling len() on a value
        # that is not a list could itself fail or report a misleading number.
        assert isinstance(new_data, list), (
            f"Lengthy rewrite for element {el_name} must return a list, "
            f"but got {type(new_data).__name__}"
        )

    def validate_length(
        self, editor_output: dict, length_factor: float, language_model: LLM
    ):
        """Shorten over-long text elements in place using ``language_model``.

        Args:
            editor_output: Element name -> ``{"data": [...]}``; rewritten data
                replaces the original list.
            length_factor: Tolerance multiplier applied to each element's
                ``suggested_characters``.
            language_model: Callable invoked as
                ``language_model(prompt, return_json=True)``.

        Raises:
            AssertionError: If the model's rewrite is not a list.
            ValueError: If an entry names an element not in this layout.
        """
        for el_name, el_data, element in self._overlong_text_elements(
            editor_output, length_factor
        ):
            new_data = language_model(
                self._rewrite_prompt(
                    el_name, el_data["data"], element.suggested_characters
                ),
                return_json=True,
            )
            # Validate before storing so a bad rewrite never replaces the data.
            self._check_rewritten(el_name, new_data)
            el_data["data"] = new_data

    async def validate_length_async(
        self, editor_output: dict, length_factor: float, language_model: AsyncLLM
    ):
        """Asynchronous variant of validate_length.

        One rewrite task per over-long text element is started in an
        ``asyncio.TaskGroup``; each result is validated and then written back
        into ``editor_output``.

        Raises:
            Errors raised inside the task group (a failed rewrite, or an
            AssertionError for a non-list rewrite) propagate wrapped in the
            exception group raised by ``asyncio.TaskGroup``.
        """
        async with asyncio.TaskGroup() as tg:
            tasks = {
                el_name: tg.create_task(
                    language_model(
                        self._rewrite_prompt(
                            el_name, el_data["data"], element.suggested_characters
                        ),
                        return_json=True,
                    )
                )
                for el_name, el_data, element in self._overlong_text_elements(
                    editor_output, length_factor
                )
            }

            for el_name, task in tasks.items():
                new_data = await task
                # Check the model's answer, not the data that is being replaced.
                self._check_rewritten(el_name, new_data)
                logger.debug(
                    f"Lengthy rewrite for {el_name}:\n From {editor_output[el_name]['data']}\n To {new_data}"
                )
                editor_output[el_name]["data"] = new_data

    @property
    def content_schema(self):
        """Schema text of all elements, joined by newlines."""
        return "\n".join(el.get_schema() for el in self.elements)

    def remove_item(self, item: str):
        """Remove ``item`` from the first element whose content contains it.

        An element left without content is dropped from the layout.

        Raises:
            ValueError: If no element contains ``item``.
        """
        for el in self.elements:
            if item not in el.content:
                continue
            el.content.remove(item)
            if not el.content:
                # Safe although we are iterating: we return immediately.
                self.elements.remove(el)
            return
        raise ValueError(f"Item {item} not found in layout {self.title}")

    def __contains__(self, key: str | int):
        """Membership test: an int is a slide index, a str is an element name.

        Raises:
            ValueError: If ``key`` is neither an int nor a str.
        """
        if isinstance(key, int):
            return key in self.slides
        if isinstance(key, str):
            return any(el.el_name == key for el in self.elements)
        raise ValueError(f"Invalid key type: {type(key)}, should be str or int")

    def __getitem__(self, key: str):
        """Return the element named ``key``.

        Raises:
            ValueError: If no element has that name.
        """
        for el in self.elements:
            if el.el_name == key:
                return el
        raise ValueError(f"Element {key} not found")

    def __iter__(self):
        """Iterate over the layout's elements."""
        return iter(self.elements)

    def __len__(self):
        """Number of elements in the layout."""
        return len(self.elements)