File size: 5,572 Bytes
d961e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import asyncio
from typing import Optional

import PIL.Image

from pptagent.llms import LLM, AsyncLLM
from pptagent.presentation import Picture, Presentation
from pptagent.utils import Config, get_logger, package_join, pbasename, pjoin

logger = get_logger(__name__)


class ImageLabler:
    """Extract image information from a presentation.

    Gathers per-image stats (size, appearance count, slide numbers) at
    construction time; captions are filled in later via ``caption_images`` or
    ``caption_images_async`` and written back onto Picture shapes with
    ``apply_stats``.
    """

    def __init__(self, presentation: Presentation, config: Config):
        """Initialize the ImageLabler and collect image stats.

        Args:
            presentation (Presentation): The presentation object.
            config (Config): The configuration object.
        """
        self.presentation = presentation
        # Slide area in square points; used to express each image's size
        # as a percentage of the slide.
        self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt
        # Keyed by image basename -> {"size", "appear_times", "slide_numbers", ...}
        self.image_stats: dict[str, dict] = {}
        self.config = config
        self.collect_images()

    def apply_stats(self, image_stats: Optional[dict[str, dict]] = None):
        """Apply image captions to the presentation's Picture shapes.

        Args:
            image_stats (Optional[dict[str, dict]]): Stats keyed by image
                basename; defaults to ``self.image_stats`` when omitted.
        """
        if image_stats is None:
            image_stats = self.image_stats

        for slide in self.presentation.slides:
            for shape in slide.shape_filter(Picture):
                if shape.caption is None:
                    caption = image_stats[pbasename(shape.img_path)]["caption"]
                    # Keep the longest line of a multi-line caption.
                    shape.caption = max(caption.split("\n"), key=len)

    async def caption_images_async(self, vision_model: AsyncLLM):
        """Generate captions for images in the presentation asynchronously.

        Args:
            vision_model (AsyncLLM): The async vision model to use for captioning.

        Returns:
            dict: Dictionary containing image stats with captions.
        """
        assert isinstance(
            vision_model, AsyncLLM
        ), "vision_model must be an AsyncLLM instance"
        with open(package_join("prompts", "caption.txt"), encoding="utf-8") as f:
            caption_prompt = f.read()

        tasks: dict[str, asyncio.Task] = {}
        async with asyncio.TaskGroup() as tg:
            for image, stats in self.image_stats.items():
                if "caption" not in stats:
                    tasks[image] = tg.create_task(
                        vision_model(
                            caption_prompt,
                            pjoin(self.config.IMAGE_DIR, image),
                        )
                    )
        # Once the TaskGroup exits, every task completed successfully
        # (any failure would have raised), so result() cannot block or raise.
        for image, task in tasks.items():
            caption = task.result()
            self.image_stats[image]["caption"] = caption
            logger.debug("captioned %s: %s", image, caption)

        self.apply_stats()
        return self.image_stats

    def caption_images(self, vision_model: LLM):
        """Generate captions for images in the presentation.

        Args:
            vision_model (LLM): The vision model to use for captioning.

        Returns:
            dict: Dictionary containing image stats with captions.
        """
        assert isinstance(vision_model, LLM), "vision_model must be an LLM instance"
        with open(package_join("prompts", "caption.txt"), encoding="utf-8") as f:
            caption_prompt = f.read()
        for image, stats in self.image_stats.items():
            if "caption" not in stats:
                stats["caption"] = vision_model(
                    caption_prompt, pjoin(self.config.IMAGE_DIR, image)
                )
                logger.debug("captioned %s: %s", image, stats["caption"])
        self.apply_stats()
        return self.image_stats

    def collect_images(self):
        """Collect images from the presentation and gather related stats.

        Populates ``self.image_stats`` with size, appearance count, the sorted
        slide numbers each image appears on, and a human-readable summary of
        the (up to three) longest consecutive slide ranges.
        """
        for slide_index, slide in enumerate(self.presentation.slides):
            for shape in slide.shape_filter(Picture):
                image_path = pbasename(shape.img_path)
                # Skip the stand-in used for missing pictures.
                if image_path == "pic_placeholder.png":
                    continue
                if image_path not in self.image_stats:
                    # Context manager ensures PIL releases the file handle.
                    with PIL.Image.open(
                        pjoin(self.config.IMAGE_DIR, image_path)
                    ) as img:
                        size = img.size
                    self.image_stats[image_path] = {
                        "size": size,
                        "appear_times": 0,
                        "slide_numbers": set(),
                        "relative_area": shape.area / self.slide_area * 100,
                    }
                self.image_stats[image_path]["appear_times"] += 1
                # Slide numbers are 1-based for display.
                self.image_stats[image_path]["slide_numbers"].add(slide_index + 1)
        for image_path, stats in self.image_stats.items():
            stats["slide_numbers"] = sorted(stats["slide_numbers"])
            ranges = self._find_ranges(stats["slide_numbers"])
            # Keep the three longest consecutive runs, e.g. "2-5, 8, 10-11".
            top_ranges = sorted(ranges, key=lambda r: r[1] - r[0], reverse=True)[:3]
            stats["top_ranges_str"] = ", ".join(
                f"{start}-{end}" if start != end else f"{start}"
                for start, end in top_ranges
            )

    def _find_ranges(self, numbers):
        """Find consecutive ranges in a sorted list of numbers.

        Args:
            numbers (list[int]): Sorted integers.

        Returns:
            list[tuple[int, int]]: Inclusive (start, end) pairs for each
            consecutive run; empty list for empty input.
        """
        if not numbers:
            return []
        ranges = []
        start = end = numbers[0]
        for num in numbers[1:]:
            if num == end + 1:
                end = num
            else:
                ranges.append((start, end))
                start = end = num
        ranges.append((start, end))
        return ranges