Spaces:
Sleeping
Sleeping
| """ | |
| Image generation module for creating infographic images | |
| """ | |
| import io | |
| import os | |
| import logging | |
| from typing import Dict, List, Tuple, Optional | |
| from PIL import Image, ImageDraw, ImageFont, ImageFilter | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib import font_manager | |
| import numpy as np | |
# Module-level logger used by ImageGenerator for progress and failure messages.
logger = logging.getLogger(__name__)
class ImageGenerator:
    """Render infographic layouts to PNG, with optional PDF and SVG export.

    A layout is a dict with an optional ``'canvas'`` entry (``width``,
    ``height``, ``background``) and an ``'elements'`` list. Each element is a
    dict with ``type`` (``title``/``section``/``icon``/``chart``/``divider``),
    ``position`` (``x``/``y``), ``size`` (``width``/``height``), ``styling``
    and ``content``.

    PNG files written by :meth:`create_infographic` (including the error
    placeholder) are recorded in ``generated_images`` so that
    :meth:`cleanup_temp_files` can delete them.
    """

    def __init__(self):
        """Resolve a system font path (may be ``None``) and start an empty file log."""
        self.default_font_path = self._get_default_font()
        self.generated_images = []
        logger.info("Image generator initialized")

    def create_infographic(self, layout_data: Dict) -> str:
        """
        Create infographic image from layout data.

        Args:
            layout_data: Complete layout specification (see class docstring).

        Returns:
            Path to the generated PNG. If rendering fails for any reason, the
            path of a generic error placeholder image is returned instead.
        """
        try:
            canvas = layout_data.get('canvas', {})
            elements = layout_data.get('elements', [])

            img = Image.new(
                'RGB',
                (canvas.get('width', 1080), canvas.get('height', 1920)),
                canvas.get('background', '#ffffff'),
            )
            draw = ImageDraw.Draw(img)
            for element in elements:
                self._draw_element(draw, element, img)

            img = self._apply_effects(img, layout_data)

            # Random name rather than id(layout_data): CPython reuses the ids of
            # freed objects, so successive layouts (e.g. the variations built in
            # create_multiple_variations) could overwrite each other's output.
            output_path = self._temp_path('infographic', '.png')
            # PNG is lossless, so Pillow has no 'quality' option for it.
            img.save(output_path, 'PNG', dpi=(300, 300))

            logger.info("Infographic generated successfully: %s", output_path)
            self.generated_images.append(output_path)
            return output_path
        except Exception:
            # Top-level boundary: keep the traceback, hand back a placeholder.
            logger.exception("Failed to generate infographic")
            return self._create_error_image()

    def _draw_element(self, draw: ImageDraw.ImageDraw, element: Dict, img: Image.Image):
        """Dispatch one layout element to the matching ``_draw_*`` helper.

        Missing ``position``/``size``/``styling`` fall back to defaults and
        ``content`` is coerced to ``str`` (``None`` becomes ``''``) because the
        helpers call string methods on it. Unsupported types are skipped.
        """
        element_type = element.get('type', 'text')
        position = element.get('position', {'x': 0, 'y': 0})
        size = element.get('size', {'width': 100, 'height': 50})
        styling = element.get('styling', {})
        content = element.get('content', '')
        content = '' if content is None else str(content)

        if element_type == 'title':
            self._draw_title(draw, content, position, size, styling)
        elif element_type == 'section':
            self._draw_section(draw, content, position, size, styling)
        elif element_type == 'icon':
            self._draw_icon(draw, content, position, size, styling)
        elif element_type == 'chart':
            self._draw_chart(draw, img, content, position, size, styling)
        elif element_type == 'divider':
            self._draw_divider(draw, position, size, styling)
        else:
            logger.debug("Skipping element with unsupported type %r", element_type)

    def _draw_title(self, draw: ImageDraw.ImageDraw, text: str, position: Dict,
                    size: Dict, styling: Dict):
        """Draw a single-line title aligned left/center/right inside its box.

        ``styling['font']`` is read as a sequence whose index 1 is the point
        size; an optional ``background_color`` paints a padded box behind the
        text.
        """
        font_info = styling.get('font', ('Arial', 32, 'bold'))
        color = styling.get('color', '#2c3e50')
        alignment = styling.get('alignment', 'left')
        font = self._load_font(font_info[1])

        left, _top, right, _bottom = draw.textbbox((0, 0), text, font=font)
        text_width = right - left

        if alignment == 'center':
            x = position['x'] + (size['width'] - text_width) // 2
        elif alignment == 'right':
            x = position['x'] + size['width'] - text_width
        else:
            x = position['x']
        y = position['y']

        bg_color = styling.get('background_color')
        if bg_color and bg_color != 'transparent':
            # Measure at the real draw origin so the box includes the font's
            # left/top bearing instead of assuming the ink starts at (x, y).
            box_left, box_top, box_right, box_bottom = draw.textbbox((x, y), text, font=font)
            draw.rectangle(
                [box_left - 10, box_top - 5, box_right + 10, box_bottom + 5],
                fill=bg_color,
            )

        draw.text((x, y), text, fill=color, font=font)

    def _draw_section(self, draw: ImageDraw.ImageDraw, text: str, position: Dict,
                      size: Dict, styling: Dict):
        """Draw an optionally filled (and rounded) box with word-wrapped text.

        Lines advance by the configured font size plus 6 px; text that does not
        fit vertically is not clipped.
        """
        font_info = styling.get('font', ('Arial', 16, 'normal'))
        color = styling.get('color', '#2c3e50')
        bg_color = styling.get('background_color', 'transparent')
        border_radius = styling.get('border_radius', 0)
        padding = styling.get('padding', 20)

        x, y = position['x'], position['y']
        width, height = size['width'], size['height']

        if bg_color and bg_color != 'transparent':
            if border_radius > 0:
                self._draw_rounded_rectangle(draw, [x, y, x + width, y + height],
                                             bg_color, border_radius)
            else:
                draw.rectangle([x, y, x + width, y + height], fill=bg_color)

        font = self._load_font(font_info[1])
        line_height = font_info[1] + 6
        text_y = y + padding
        for line in self._wrap_text(text, font, width - padding * 2):
            draw.text((x + padding, text_y), line, fill=color, font=font)
            text_y += line_height

    def _draw_icon(self, draw: ImageDraw.ImageDraw, description: str, position: Dict,
                   size: Dict, styling: Dict):
        """Draw a geometric placeholder icon chosen by keywords in *description*.

        'chart'/'data'/'graph' -> three bars, 'process'/'flow'/'step' -> a
        diamond, anything else -> a filled circle.
        """
        color = styling.get('color', '#3498db')
        icon_size = min(size['width'], size['height'])
        center_x = position['x'] + size['width'] // 2
        center_y = position['y'] + size['height'] // 2
        radius = icon_size // 3
        keywords = description.lower()

        if any(word in keywords for word in ('chart', 'data', 'graph')):
            bar_width = radius // 2
            for i in range(3):
                bar_height = radius * (0.5 + i * 0.3)
                bar_x = center_x - radius + i * bar_width
                bar_y = center_y + radius - bar_height
                # Keep x1 >= x0 for very small icons; Pillow rejects inverted boxes.
                bar_right = bar_x + max(bar_width - 2, 0)
                draw.rectangle([bar_x, bar_y, bar_right, center_y + radius], fill=color)
        elif any(word in keywords for word in ('process', 'flow', 'step')):
            draw.polygon([
                (center_x - radius, center_y),
                (center_x, center_y - radius // 2),
                (center_x + radius, center_y),
                (center_x, center_y + radius // 2),
            ], fill=color)
        else:
            draw.ellipse([
                center_x - radius, center_y - radius,
                center_x + radius, center_y + radius,
            ], fill=color)

    def _draw_chart(self, draw: ImageDraw.ImageDraw, img: Image.Image, description: str,
                    position: Dict, size: Dict, styling: Dict):
        """Render a sample bar chart with matplotlib and paste it onto *img*.

        The plotted values are fixed sample data; the first 30 characters of
        *description* become the title. On any rendering error an outline
        rectangle is drawn in the chart's place.
        """
        x, y = int(position['x']), int(position['y'])
        width, height = int(size['width']), int(size['height'])
        color = styling.get('color', '#3498db')
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
            categories = ['A', 'B', 'C', 'D']
            values = [23, 45, 56, 78]
            ax.bar(categories, values, color=[color] * len(categories))
            ax.set_title(description[:30], fontsize=12)
            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')

            # Save via the figure object (not plt.savefig) so pyplot's
            # "current figure" state cannot redirect the output.
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
                buf.seek(0)
                with Image.open(buf) as chart_img:
                    chart = chart_img.resize((width, height))
            img.paste(chart, (x, y))
        except Exception as e:
            logger.error("Failed to draw chart: %s", e)
            draw.rectangle([x, y, x + width, y + height], outline=color, width=2)
        finally:
            # Always release the figure: pyplot keeps open figures referenced.
            if fig is not None:
                plt.close(fig)

    def _draw_divider(self, draw: ImageDraw.ImageDraw, position: Dict, size: Dict, styling: Dict):
        """Draw a horizontal line across the element box at half its height."""
        color = styling.get('color', '#bdc3c7')
        thickness = styling.get('thickness', 2)
        mid_y = position['y'] + size['height'] // 2
        draw.line([position['x'], mid_y, position['x'] + size['width'], mid_y],
                  fill=color, width=thickness)

    def _draw_rounded_rectangle(self, draw: ImageDraw.ImageDraw, coords: List[int],
                                fill_color: str, radius: int):
        """Fill a rectangle with rounded corners of the given *radius*.

        The radius is clamped to half the shorter side; without this, a large
        radius yields inverted boxes that Pillow rejects.
        """
        x1, y1, x2, y2 = coords
        radius = max(0, min(radius, (x2 - x1) // 2, (y2 - y1) // 2))
        if radius == 0:
            draw.rectangle([x1, y1, x2, y2], fill=fill_color)
            return

        # Two overlapping rectangles form a cross; quarter-pies fill the corners.
        draw.rectangle([x1 + radius, y1, x2 - radius, y2], fill=fill_color)
        draw.rectangle([x1, y1 + radius, x2, y2 - radius], fill=fill_color)
        draw.pieslice([x1, y1, x1 + 2 * radius, y1 + 2 * radius], 180, 270, fill=fill_color)
        draw.pieslice([x2 - 2 * radius, y1, x2, y1 + 2 * radius], 270, 360, fill=fill_color)
        draw.pieslice([x1, y2 - 2 * radius, x1 + 2 * radius, y2], 90, 180, fill=fill_color)
        draw.pieslice([x2 - 2 * radius, y2 - 2 * radius, x2, y2], 0, 90, fill=fill_color)

    def _wrap_text(self, text: str, font: ImageFont.ImageFont, max_width: int) -> List[str]:
        """Greedily wrap *text* on whitespace into lines at most *max_width* px wide.

        A single word wider than *max_width* is emitted on its own line
        unbroken. Returns ``[]`` for empty or whitespace-only text.
        """
        lines = []
        current_line = []
        for word in text.split():
            candidate = ' '.join(current_line + [word])
            left, _top, right, _bottom = font.getbbox(candidate)
            if right - left <= max_width:
                current_line.append(word)
            elif current_line:
                lines.append(' '.join(current_line))
                current_line = [word]
            else:
                lines.append(word)
        if current_line:
            lines.append(' '.join(current_line))
        return lines

    def _apply_effects(self, img: Image.Image, layout_data: Dict) -> Image.Image:
        """Post-processing hook; currently returns *img* unchanged.

        *layout_data* is accepted so future effects can be layout-driven.
        """
        return img

    def _get_default_font(self) -> Optional[str]:
        """Return a system font path, preferring Arial/Liberation; ``None`` if none found."""
        try:
            available_fonts = font_manager.findSystemFonts()
        except Exception as e:  # font discovery is best-effort
            logger.warning("System font discovery failed: %s", e)
            return None
        for font_path in available_fonts:
            lowered = font_path.lower()
            if 'arial' in lowered or 'liberation' in lowered:
                return font_path
        return available_fonts[0] if available_fonts else None

    def _load_font(self, size: int) -> ImageFont.ImageFont:
        """Load the default TrueType font at *size*, else Pillow's built-in font."""
        if self.default_font_path:
            try:
                return ImageFont.truetype(self.default_font_path, size)
            except (OSError, TypeError, ValueError) as e:
                logger.debug("Could not load font %s: %s", self.default_font_path, e)
        return ImageFont.load_default()

    @staticmethod
    def _temp_path(prefix: str, suffix: str) -> str:
        """Return a unique ``<tempdir>/<prefix>_<hex><suffix>`` path (file not created)."""
        import tempfile
        import uuid

        return os.path.join(tempfile.gettempdir(), f"{prefix}_{uuid.uuid4().hex}{suffix}")

    def _create_error_image(self) -> str:
        """Write a 1080x1920 placeholder PNG with an error message; return its path."""
        width, height = 1080, 1920
        img = Image.new('RGB', (width, height), '#f8f9fa')
        draw = ImageDraw.Draw(img)
        font = self._load_font(24)

        error_text = "Error generating infographic\nPlease try again"
        left, top, right, bottom = draw.textbbox((0, 0), error_text, font=font)
        x = (width - (right - left)) // 2
        y = (height - (bottom - top)) // 2
        draw.text((x, y), error_text, fill='#e74c3c', font=font)

        # Unique path so concurrent failures don't race on one shared file;
        # tracked so cleanup_temp_files removes it.
        error_path = self._temp_path('error_infographic', '.png')
        img.save(error_path, 'PNG')
        self.generated_images.append(error_path)
        return error_path

    def create_multiple_variations(self, layout_data: Dict, count: int = 3) -> List[str]:
        """Render *count* background variations of *layout_data*; return their paths."""
        return [
            self.create_infographic(self._create_variation(layout_data, i))
            for i in range(count)
        ]

    def _create_variation(self, layout_data: Dict, variation_index: int) -> Dict:
        """Return a deep copy of *layout_data* with a per-variation background.

        Index 1 uses ``#f5f6fa``, index 2 uses ``#ffffff``; any other index
        returns an unmodified copy. The caller's layout is never mutated.
        """
        import copy

        # Deep copy so the nested 'canvas' dict is not shared with the caller.
        variation = copy.deepcopy(layout_data)
        backgrounds = {1: '#f5f6fa', 2: '#ffffff'}
        if variation_index in backgrounds:
            variation.setdefault('canvas', {})['background'] = backgrounds[variation_index]
        return variation

    def export_to_pdf(self, image_path: str) -> str:
        """Wrap the image at *image_path* in a single-page PDF sized to the image.

        Returns:
            Path of the PDF (same name, ``.pdf`` extension), or *image_path*
            unchanged if the export fails (e.g. reportlab is not installed).
        """
        try:
            from reportlab.pdfgen import canvas
            from reportlab.lib.utils import ImageReader

            pdf_path = self._pdf_path_for(image_path)
            with Image.open(image_path) as img:
                c = canvas.Canvas(pdf_path, pagesize=(img.width, img.height))
                c.drawImage(ImageReader(img), 0, 0, width=img.width, height=img.height)
                c.save()
            return pdf_path
        except Exception as e:
            logger.error("PDF export failed: %s", e)
            return image_path

    @staticmethod
    def _pdf_path_for(image_path: str) -> str:
        """Return *image_path* with only its final extension replaced by ``.pdf``.

        ``str.replace('.png', ...)`` would also rewrite directory names and
        leave ``.PNG`` paths unchanged (overwriting the source image).
        """
        root, _ext = os.path.splitext(image_path)
        return root + '.pdf'

    def export_to_svg(self, layout_data: Dict) -> str:
        """Export the layout's titles, sections and icons as a simple SVG file.

        Returns:
            Path to the written ``.svg`` file, or ``""`` if the export failed.
        """
        try:
            canvas = layout_data.get('canvas', {})
            elements = layout_data.get('elements', [])
            width = self._esc(canvas.get('width', 1080))
            height = self._esc(canvas.get('height', 1920))
            background = self._esc(canvas.get('background', '#ffffff'))

            parts = [
                '<?xml version="1.0" encoding="UTF-8"?>\n',
                f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n',
                f'<rect width="100%" height="100%" fill="{background}"/>\n',
            ]
            parts.extend(self._element_to_svg(element) for element in elements)
            parts.append('</svg>')

            svg_path = self._temp_path('infographic', '.svg')
            with open(svg_path, 'w', encoding='utf-8') as f:
                f.write(''.join(parts))
            return svg_path
        except Exception as e:
            logger.error("SVG export failed: %s", e)
            return ""

    def _element_to_svg(self, element: Dict) -> str:
        """Convert one element to SVG markup (``''`` for unsupported types).

        All interpolated values are XML-escaped so content such as ``&`` or
        ``<`` cannot break the document or inject markup.
        """
        element_type = element.get('type', 'text')
        position = element.get('position', {'x': 0, 'y': 0})
        size = element.get('size', {'width': 100, 'height': 50})
        styling = element.get('styling', {})
        content = element.get('content', '')

        if element_type in ('title', 'section'):
            color = styling.get('color', '#2c3e50')
            font_size = styling.get('font', ['Arial', 16, 'normal'])[1]
            # SVG text y is the baseline, so shift down by one font size.
            x = self._esc(position['x'])
            y = self._esc(position['y'] + font_size)
            return (f'<text x="{x}" y="{y}" fill="{self._esc(color)}" '
                    f'font-size="{self._esc(font_size)}">{self._esc(content)}</text>\n')
        if element_type == 'icon':
            color = styling.get('color', '#3498db')
            cx = position['x'] + size['width'] // 2
            cy = position['y'] + size['height'] // 2
            r = min(size['width'], size['height']) // 3
            return (f'<circle cx="{self._esc(cx)}" cy="{self._esc(cy)}" '
                    f'r="{self._esc(r)}" fill="{self._esc(color)}"/>\n')
        return ""

    @staticmethod
    def _esc(value) -> str:
        """Escape *value* for SVG text content or a double-quoted attribute."""
        from html import escape

        return escape(str(value), quote=True)

    def cleanup_temp_files(self):
        """Delete every file recorded in ``generated_images`` and clear the list.

        Already-missing files are ignored; other OS errors are logged rather
        than silently swallowed.
        """
        for file_path in self.generated_images:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning("Could not remove temp file %s: %s", file_path, e)
        self.generated_images.clear()