Spaces:
Runtime error
Runtime error
File size: 5,572 Bytes
d961e88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import asyncio
from typing import Optional
import PIL.Image
from pptagent.llms import LLM, AsyncLLM
from pptagent.presentation import Picture, Presentation
from pptagent.utils import Config, get_logger, package_join, pbasename, pjoin
# Module-level logger obtained from pptagent.utils.get_logger.
logger = get_logger(__name__)
class ImageLabler:
    """
    Extract image information from a presentation: caption, pixel size and appearances.

    After construction, ``self.image_stats`` maps each image's basename to a dict with:
        - "size": the value of ``PIL.Image.open(...).size`` for the image file.
        - "appear_times": number of Picture shapes referencing the image.
        - "slide_numbers": sorted list of 1-based slide numbers containing the image.
        - "relative_area": area of the first shape referencing the image as a
          percentage of the slide area (assumes ``shape.area`` is in points², like
          ``slide_width.pt * slide_height.pt`` — TODO confirm).
        - "top_ranges_str": up to three longest runs of consecutive slide numbers,
          longest first, e.g. ``"3-5, 8"``.
        - "caption": added by ``caption_images`` / ``caption_images_async``.

    The class name keeps its original spelling ("Labler") for backward compatibility.
    """

    # Basename of the placeholder picture; it is never collected or captioned.
    PLACEHOLDER_IMAGE = "pic_placeholder.png"

    def __init__(self, presentation: Presentation, config: Config):
        """
        Initialize the ImageLabler and collect image statistics.

        Args:
            presentation (Presentation): The presentation object.
            config (Config): The configuration object; images are read from
                ``config.IMAGE_DIR``.
        """
        self.presentation = presentation
        self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt
        self.image_stats: dict[str, dict] = {}
        self.config = config
        self.collect_images()

    @staticmethod
    def _load_caption_prompt() -> str:
        """Read the captioning prompt template shipped with the package."""
        with open(package_join("prompts", "caption.txt"), encoding="utf-8") as f:
            return f.read()

    def apply_stats(self, image_stats: Optional[dict[str, dict]] = None):
        """
        Apply image captions to the presentation's Picture shapes.

        Shapes that already have a caption are left unchanged. Otherwise the
        longest line of the stored caption becomes the shape's caption.

        Args:
            image_stats (dict, optional): Image stats keyed by image basename.
                Defaults to ``self.image_stats``.
        """
        if image_stats is None:
            image_stats = self.image_stats
        for slide in self.presentation.slides:
            for shape in slide.shape_filter(Picture):
                if shape.caption is not None:
                    continue
                image = pbasename(shape.img_path)
                # collect_images skips placeholders, so they never have stats.
                if image == self.PLACEHOLDER_IMAGE:
                    continue
                caption = image_stats.get(image, {}).get("caption")
                if caption is None:
                    # e.g. captioning has not run (or failed) for this image.
                    logger.warning("no caption available for image %s", image)
                    continue
                shape.caption = max(caption.split("\n"), key=len)

    async def _caption_image_async(
        self, vision_model: AsyncLLM, caption_prompt: str, image: str
    ):
        """Caption a single image and store the result in ``self.image_stats``."""
        caption = await vision_model(
            caption_prompt,
            pjoin(self.config.IMAGE_DIR, image),
        )
        # Stored as soon as this image finishes, so partial progress survives
        # a failure in another task.
        self.image_stats[image]["caption"] = caption
        logger.debug("captioned %s: %s", image, caption)

    async def caption_images_async(self, vision_model: AsyncLLM):
        """
        Generate captions for images in the presentation asynchronously.

        Images that already have a "caption" entry are skipped. All remaining
        images are captioned concurrently inside an ``asyncio.TaskGroup``; if any
        task fails, the group's exception group propagates after the remaining
        tasks are cancelled, and captions already obtained are kept.

        Args:
            vision_model (AsyncLLM): The async vision model to use for captioning.

        Returns:
            dict: Dictionary containing image stats with captions.
        """
        assert isinstance(
            vision_model, AsyncLLM
        ), "vision_model must be an AsyncLLM instance"
        caption_prompt = self._load_caption_prompt()
        async with asyncio.TaskGroup() as tg:
            for image, stats in self.image_stats.items():
                if "caption" not in stats:
                    tg.create_task(
                        self._caption_image_async(vision_model, caption_prompt, image)
                    )
        self.apply_stats()
        return self.image_stats

    def caption_images(self, vision_model: LLM):
        """
        Generate captions for images in the presentation, one at a time.

        Images that already have a "caption" entry are skipped.

        Args:
            vision_model (LLM): The vision model to use for captioning.

        Returns:
            dict: Dictionary containing image stats with captions.
        """
        assert isinstance(vision_model, LLM), "vision_model must be an LLM instance"
        caption_prompt = self._load_caption_prompt()
        for image, stats in self.image_stats.items():
            if "caption" not in stats:
                stats["caption"] = vision_model(
                    caption_prompt, pjoin(self.config.IMAGE_DIR, image)
                )
                logger.debug("captioned %s: %s", image, stats["caption"])
        self.apply_stats()
        return self.image_stats

    def collect_images(self):
        """
        Collect images from the presentation and gather per-image statistics.

        Populates ``self.image_stats`` (see the class docstring for the schema).
        The placeholder picture is ignored.
        """
        for slide_number, slide in enumerate(self.presentation.slides, start=1):
            for shape in slide.shape_filter(Picture):
                image_path = pbasename(shape.img_path)
                if image_path == self.PLACEHOLDER_IMAGE:
                    continue
                stats = self.image_stats.get(image_path)
                if stats is None:
                    # Context manager closes the file handle PIL opens lazily.
                    with PIL.Image.open(
                        pjoin(self.config.IMAGE_DIR, image_path)
                    ) as img:
                        size = img.size
                    stats = self.image_stats[image_path] = {
                        "size": size,
                        "appear_times": 0,
                        "slide_numbers": set(),
                        # Percentage of the slide covered by the first occurrence.
                        "relative_area": shape.area / self.slide_area * 100,
                    }
                stats["appear_times"] += 1
                stats["slide_numbers"].add(slide_number)
        for stats in self.image_stats.values():
            stats["slide_numbers"] = sorted(stats["slide_numbers"])
            ranges = self._find_ranges(stats["slide_numbers"])
            # Longest runs first; sorted() is stable so ties keep slide order.
            top_ranges = sorted(ranges, key=lambda r: r[1] - r[0], reverse=True)[:3]
            stats["top_ranges_str"] = ", ".join(
                f"{start}-{end}" if start != end else f"{start}"
                for start, end in top_ranges
            )

    def _find_ranges(self, numbers):
        """
        Find runs of consecutive integers in a sorted list.

        Example: ``[1, 2, 3, 5, 7, 8]`` -> ``[(1, 3), (5, 5), (7, 8)]``.

        Args:
            numbers (list[int]): Sorted integers.

        Returns:
            list[tuple[int, int]]: Inclusive ``(start, end)`` pairs; empty for empty input.
        """
        if not numbers:
            return []
        ranges = []
        start = end = numbers[0]
        for num in numbers[1:]:
            if num == end + 1:
                end = num
            else:
                ranges.append((start, end))
                start = end = num
        ranges.append((start, end))
        return ranges
|