# slidegent/src/ppt/deck_generator.py
# Uploaded by jomasego via huggingface_hub ("Upload folder using huggingface_hub"),
# commit ad0ebea (verified), 5.85 kB.
from __future__ import annotations
import json
import re
import tempfile
from pathlib import Path
from pptx import Presentation
from pptx.dml.color import RGBColor
from pptx.util import Pt
from src.ppt.placeholder_mapper import choose_layout_index
from src.utils.validators import validate_slides, validate_template
def _extract_json_array_block(raw_text: str) -> str:
    """Pull the JSON array text out of raw LLM output.

    A leading markdown fence (optionally tagged ``json``, any case) and a
    trailing fence are stripped first. If what remains is already a bracketed
    array it is returned as-is; otherwise the span from the first ``[`` to the
    last ``]`` is returned.

    Raises:
        ValueError: If no bracketed span can be found.
    """
    cleaned = re.sub(r"^```(?:json)?", "", raw_text.strip(), flags=re.IGNORECASE)
    cleaned = re.sub(r"```$", "", cleaned.strip()).strip()

    looks_like_array = cleaned.startswith("[") and cleaned.endswith("]")
    if looks_like_array:
        return cleaned

    # Greedy on purpose: nested arrays inside the payload must stay intact.
    found = re.search(r"\[[\s\S]*\]", cleaned)
    if found is None:
        raise ValueError("No JSON array block found in model output.")
    return found.group(0)
def _normalize_title(raw_title) -> str:
    """Return a stripped title, using "Untitled Slide" for None or blank values."""
    # Guard None explicitly: str(None) would otherwise yield the title "None".
    if raw_title is None:
        return "Untitled Slide"
    return str(raw_title).strip() or "Untitled Slide"


def _normalize_bullets(item: dict) -> list[str]:
    """Return the slide's bullets as a list of non-empty, stripped strings."""
    raw = item.get("bullets")
    if raw is None:
        # Fall back to "content" both when "bullets" is absent and when it is null.
        raw = item.get("content", [])
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, (list, tuple)):
        return []
    bullets = []
    for entry in raw:
        if entry is None:
            continue  # a null entry must not become a literal "None" bullet
        text = str(entry).strip()
        if text:
            bullets.append(text)
    return bullets


def _normalize_layout_type(raw_layout) -> str:
    """Return the layout type string, using "content_slide" for None or blank values."""
    if raw_layout is None:
        return "content_slide"
    return str(raw_layout).strip() or "content_slide"


def parse_slides_payload(slides_payload: str | list[dict]) -> list[dict]:
    """Parse and normalise a slide payload into a list of slide dicts.

    Args:
        slides_payload: Either an already-parsed list of slide dicts, or raw
            model output containing a JSON array (optionally wrapped in a
            markdown code fence).

    Returns:
        One dict per input dict item, each with keys ``title`` (non-empty
        str), ``bullets`` (list of non-empty str, taken from ``bullets`` or,
        failing that, ``content``) and ``layout_type`` (str). Non-dict items
        are skipped. The result is passed through ``validate_slides`` before
        being returned.

    Raises:
        ValueError: If the payload has an unsupported type, contains no JSON
            array, or the JSON is malformed. ``validate_slides`` may raise too.
    """
    if isinstance(slides_payload, list):
        parsed = slides_payload
    elif isinstance(slides_payload, str):
        block = _extract_json_array_block(slides_payload)
        try:
            parsed = json.loads(block)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Malformed slide JSON: {exc}") from exc
    else:
        raise ValueError("slides_payload must be a JSON string or list of dict objects.")

    normalized = [
        {
            "title": _normalize_title(item.get("title")),
            "bullets": _normalize_bullets(item),
            "layout_type": _normalize_layout_type(item.get("layout_type")),
        }
        for item in parsed
        if isinstance(item, dict)
    ]
    validate_slides(normalized)
    return normalized
def _apply_run_style(run, style_profile: dict):
    """Apply font settings from a template style profile to a text run.

    Reads ``style_profile["font"]`` (keys ``name``, ``size_pt``, ``bold``,
    ``italic``, ``color_rgb``). For the colour it falls back to the
    profile-level ``dominant_text_color``. Missing, ``None`` or empty values
    leave the run's inherited formatting untouched.

    Args:
        run: A python-pptx text run (anything exposing a ``font`` attribute).
        style_profile: Template style profile; ``None`` or ``{}`` is a no-op.
    """
    profile = style_profile or {}
    font_style = profile.get("font") or {}

    if font_style.get("name"):
        run.font.name = font_style["name"]
    if font_style.get("size_pt"):
        run.font.size = Pt(font_style["size_pt"])
    if font_style.get("bold") is not None:
        run.font.bold = bool(font_style["bold"])
    if font_style.get("italic") is not None:
        run.font.italic = bool(font_style["italic"])

    # No early return when the "font" section is missing: the profile-level
    # dominant_text_color is a colour source in its own right.
    color = font_style.get("color_rgb") or profile.get("dominant_text_color")
    if isinstance(color, (list, tuple)) and len(color) == 3:
        # RGBColor rejects components outside 0-255; clamp instead of letting
        # one bad profile value abort the whole deck.
        red, green, blue = (max(0, min(255, int(component))) for component in color)
        run.font.color.rgb = RGBColor(red, green, blue)
def _remove_all_slides(prs: Presentation):
    """Delete every existing slide from *prs* in place.

    Only slide ID entries and the presentation part's relationships to the
    slides are touched; slide layouts and masters are left as they are, so
    new slides can still be added from the template's layouts.

    NOTE(review): this reaches into python-pptx internals
    (``prs.slides._sldIdLst``) because python-pptx appears to have no public
    slide-deletion API. Re-check after python-pptx upgrades.

    Args:
        prs: An opened python-pptx presentation (modified in place).
    """
    # Remove slide relationship + slide ID to avoid duplicate part warnings on save.
    # Walk indices from last to first so deleting entry ``idx`` never shifts
    # the position of an entry that has not been visited yet.
    for idx in range(len(prs.slides) - 1, -1, -1):
        slide_id = prs.slides._sldIdLst[idx] # pylint: disable=protected-access
        rel_id = slide_id.rId
        # Drop the presentation part's relationship to this slide's part...
        prs.part.drop_rel(rel_id)
        # ...then remove the slide ID entry so the slide leaves the deck's slide list.
        del prs.slides._sldIdLst[idx] # pylint: disable=protected-access
def _fill_title(title_shape, title: str, style_profile: dict) -> None:
    """Replace the title placeholder's text (if the slide has one) with *title*."""
    if title_shape is None:
        return
    text_frame = title_shape.text_frame
    text_frame.clear()
    run = text_frame.paragraphs[0].add_run()
    run.text = title
    _apply_run_style(run, style_profile)


def _find_body_placeholder(slide, title_shape):
    """Return the first text-capable placeholder that is not the title, or None."""
    for shape in slide.placeholders:
        # Use != (python-pptx shape equality compares the underlying XML
        # element), NOT `is not`: every access builds a fresh proxy object,
        # so an identity test never matches and the title placeholder itself
        # would be picked as the body.
        if shape != title_shape and getattr(shape, "has_text_frame", False):
            return shape
    return None


def _fill_body(body_shape, bullets: list[str], style_profile: dict) -> None:
    """Replace the body placeholder's text with one paragraph per bullet."""
    if body_shape is None:
        return
    text_frame = body_shape.text_frame
    text_frame.clear()  # leaves exactly one empty paragraph, reused for bullet 0
    for idx, bullet in enumerate(bullets):
        paragraph = text_frame.paragraphs[0] if idx == 0 else text_frame.add_paragraph()
        run = paragraph.add_run()
        run.text = bullet
        _apply_run_style(run, style_profile)


def generate_presentation(
    slides_payload: str | list[dict],
    template_path: str | Path,
    output_path: str | Path | None = None,
    style_profile: dict | None = None,
) -> Presentation:
    """Build a presentation from slide data on top of a template.

    All existing slides in the template are removed first. Then one slide is
    added per normalised slide dict, using the layout picked by
    ``choose_layout_index`` for its ``layout_type``. Each slide gets its title
    in the title placeholder and its bullets in the first other text-capable
    placeholder.

    Args:
        slides_payload: Slide list or raw model output (see
            ``parse_slides_payload``).
        template_path: Path to the template, checked by ``validate_template``.
        output_path: If truthy, the presentation is also saved there.
        style_profile: Optional style profile used for layout choice and run
            styling.

    Returns:
        The in-memory python-pptx presentation.

    Raises:
        ValueError: From ``parse_slides_payload`` on bad payloads.
            ``validate_template`` may raise for bad templates.
    """
    validated_template = validate_template(template_path)
    slides = parse_slides_payload(slides_payload)
    profile = style_profile or {}

    prs = Presentation(str(validated_template))
    _remove_all_slides(prs)
    layout_count = len(prs.slide_layouts)  # layouts are not modified below

    for slide_data in slides:
        layout_idx = choose_layout_index(
            slide_data.get("layout_type", "content_slide"),
            profile,
            layout_count,
        )
        slide = prs.slides.add_slide(prs.slide_layouts[layout_idx])
        title_shape = slide.shapes.title
        _fill_title(title_shape, slide_data["title"], profile)
        _fill_body(_find_body_placeholder(slide, title_shape), slide_data["bullets"], profile)

    if output_path:
        prs.save(str(output_path))
    return prs
def compile_presentation(
    slides_payload: str | list[dict],
    template_path: str | Path,
    style_profile: dict | None = None,
    output_path: str | Path | None = None,
) -> str:
    """Render slides into a .pptx file and return that file's path as a string.

    Args:
        slides_payload: Slide list or raw model output.
        template_path: Path to the .pptx template.
        style_profile: Optional style profile forwarded to generation.
        output_path: Destination file. If falsy, a new uniquely named
            ``.pptx`` file is created in the system temp directory. It is not
            deleted automatically, so the caller owns it.

    Returns:
        The path the presentation was saved to.

    Raises:
        Whatever ``generate_presentation`` raises. If a temp file was created
        for this call, it is removed before the error propagates.
    """
    created_temp = not output_path
    if created_temp:
        # Close the handle right away: only the reserved filename is needed,
        # and prs.save re-opens the path itself.
        with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as handle:
            destination = Path(handle.name)
    else:
        destination = Path(output_path)

    try:
        generate_presentation(
            slides_payload=slides_payload,
            template_path=template_path,
            output_path=destination,
            style_profile=style_profile,
        )
    except Exception:
        if created_temp:
            try:
                destination.unlink()
            except OSError:
                pass  # best-effort cleanup; the original error is what matters
        raise
    return str(destination)
def create_deck(template_data: dict, slide_data: str | list[dict], template_path: str | Path, output_path: str | Path | None = None):
    """Backward-compatible entry point for the service layer.

    ``template_data`` is used as the style profile and ``slide_data`` as the
    slide payload. Everything is delegated to ``compile_presentation``, whose
    returned path string is passed straight back.
    """
    deck_path = compile_presentation(
        output_path=output_path,
        style_profile=template_data,
        template_path=template_path,
        slides_payload=slide_data,
    )
    return deck_path