Spaces:
Sleeping
Sleeping
| import os | |
| from PIL import Image, ImageDraw, ImageFont | |
| import json | |
| import gradio as gr | |
| from google import genai | |
| from google.genai import types | |
# Initialize the Google Gemini client used by generate_flowchart.
# The API key is read from the GOOGLE_API_KEY environment variable; if the
# variable is not set, the lookup raises KeyError when this module is loaded.
client = genai.Client(api_key=os.environ['GOOGLE_API_KEY'])
# Model identifier passed to client.models.generate_content.
model_name = "gemini-2.0-flash-exp"
# Function to parse JSON output from Gemini
def parse_json(json_output):
    """
    Parse JSON output from the Gemini model.

    The model often wraps its answer in a markdown code fence. The fence is
    located tolerantly: surrounding whitespace and letter case on the fence
    line are ignored, a fence tagged ``json`` is preferred, and the first
    untagged fence is used as a fallback. Text without any fence is parsed
    as-is.

    Args:
        json_output: Raw text returned by the model.

    Returns:
        The decoded JSON value, or an empty dict if the text cannot be parsed
        (including when json_output is not a string).
    """
    try:
        lines = json_output.splitlines()
        fence_index = None
        for i, line in enumerate(lines):
            marker = line.strip().lower()
            if not marker.startswith("```"):
                continue
            language = marker[3:].strip()
            if language == "json":
                # An explicitly tagged JSON block always wins.
                fence_index = i
                break
            if language == "" and fence_index is None:
                # Remember the first untagged fence in case no json fence exists.
                fence_index = i
        if fence_index is not None:
            json_output = "\n".join(lines[fence_index + 1:])  # Drop everything up to the opening fence
            json_output = json_output.split("```")[0]  # Drop everything after the closing fence
        return json.loads(json_output)
    except Exception as e:
        # Best-effort by design: callers treat an empty dict as "no result".
        print(f"Error parsing JSON: {e}")
        return {}
# Helper: look up a shape by id
def _find_shape(shapes, shape_id):
    """Return the first shape whose "id" equals shape_id, or None if absent."""
    if shape_id is None:
        return None
    return next((s for s in shapes if s.get("id") == shape_id), None)


# Helper: draw an arrowhead aligned with a connector
def _draw_arrowhead(draw, start, tip, size=10):
    """Draw a filled triangle at tip pointing along the start -> tip direction."""
    dx = tip[0] - start[0]
    dy = tip[1] - start[1]
    length = (dx * dx + dy * dy) ** 0.5
    if length == 0:
        # Zero-length connector: there is no direction to point in.
        return
    ux, uy = dx / length, dy / length
    # Centre of the triangle's base, set back from the tip along the line.
    base_x = tip[0] - ux * size
    base_y = tip[1] - uy * size
    # (-uy, ux) is perpendicular to the line direction.
    half = size / 2
    draw.polygon(
        [
            (tip[0], tip[1]),
            (base_x - uy * half, base_y + ux * half),
            (base_x + uy * half, base_y - ux * half),
        ],
        fill="black",
    )


# Function to draw a flowchart
def draw_flowchart(image, flowchart_json):
    """
    Draws a flowchart on a copy of the given image based on JSON input.

    Args:
        image: PIL image used as the canvas; it is copied, not modified.
        flowchart_json: Dict with optional "shapes" and "connections" lists.
            Each shape must provide "x", "y", "width" and "height", and may
            provide "id", "type" ("rectangle", "ellipse" or "diamond"),
            "label" and "color". Each connection refers to shapes through
            "from" and "to" ids and may provide "type" (default "arrow").

    Returns:
        The copy of image with the shapes, labels and connectors drawn on it.

    Raises:
        KeyError: If a shape lacks one of its required geometry keys.
    """
    im = image.copy()
    draw = ImageDraw.Draw(im)
    # Load default font
    try:
        font = ImageFont.load_default()
    except Exception as e:
        print(f"Error loading font: {e}")
        return im
    shapes = flowchart_json.get("shapes", [])
    connections = flowchart_json.get("connections", [])
    # Draw shapes
    for shape in shapes:
        x, y, w, h = shape["x"], shape["y"], shape["width"], shape["height"]
        shape_type = shape.get("type", "rectangle").lower()
        label = shape.get("label", "")
        color = shape.get("color", "white")
        # Draw the shape (unknown types get a label but no outline)
        if shape_type == "rectangle":
            draw.rectangle([x, y, x + w, y + h], fill=color, outline="black", width=3)
        elif shape_type == "ellipse":
            draw.ellipse([x, y, x + w, y + h], fill=color, outline="black", width=3)
        elif shape_type == "diamond":
            points = [
                (x + w // 2, y),      # Top
                (x + w, y + h // 2),  # Right
                (x + w // 2, y + h),  # Bottom
                (x, y + h // 2),      # Left
            ]
            draw.polygon(points, fill=color, outline="black")
        # Centre the label inside the shape's bounding box using getbbox
        bbox = font.getbbox(label)
        text_w = bbox[2] - bbox[0]
        text_h = bbox[3] - bbox[1]
        text_x = x + (w - text_w) // 2
        text_y = y + (h - text_h) // 2
        # Add the label
        draw.text((text_x, text_y), label, fill="black", font=font)
    # Draw connections
    for conn in connections:
        from_shape = _find_shape(shapes, conn.get("from"))
        to_shape = _find_shape(shapes, conn.get("to"))
        if from_shape is None or to_shape is None:
            # One dangling reference should not abort the whole drawing.
            print(f"Skipping connection with unknown endpoint: {conn}")
            continue
        # Connectors run from the bottom-centre of the source to the top-centre of the target
        x1, y1 = from_shape["x"] + from_shape["width"] // 2, from_shape["y"] + from_shape["height"]
        x2, y2 = to_shape["x"] + to_shape["width"] // 2, to_shape["y"]
        # Draw the line
        draw.line([x1, y1, x2, y2], fill="black", width=2)
        # Add arrowhead for arrows, oriented along the connector
        if conn.get("type", "arrow") == "arrow":
            _draw_arrowhead(draw, (x1, y1), (x2, y2), size=10)
    return im
# Legacy name for the flowchart drawing function
def olddraw_flowchart(image, flowchart_json):
    """
    Draws a flowchart on a copy of the given image based on JSON input.

    Kept under its old name so existing callers continue to work. The earlier
    body duplicated draw_flowchart but measured labels with font.getsize(),
    which was removed in Pillow 10 and raises AttributeError there, so this
    function now delegates to draw_flowchart instead.

    Args:
        image: PIL image used as the canvas.
        flowchart_json: Dict with "shapes" and "connections" lists.

    Returns:
        Whatever draw_flowchart returns for the same arguments.
    """
    return draw_flowchart(image, flowchart_json)
# Function to generate flowchart JSON via Gemini
def generate_flowchart(prompt):
    """
    Use Google Gemini to generate JSON for a flowchart.

    Args:
        prompt: Description of the process to turn into a flowchart.

    Returns:
        The value produced by parse_json for the model's reply, or an empty
        dict when the prompt is blank, the request fails, or the reply
        contains no text.
    """
    # A blank prompt cannot yield a meaningful chart; skip the API round trip.
    if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
        print("Error generating flowchart JSON: prompt is empty")
        return {}
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt],
            config=types.GenerateContentConfig(
                system_instruction="""
                Return a JSON structure describing a flowchart.
                Use formal flowchart conventions with shapes like rectangles, ellipses, and diamonds.
                Each shape should have attributes: id, label, x, y, width, height, type (e.g., 'rectangle', 'ellipse', 'diamond'), and color.
                Also include connections with attributes: from (id), to (id), and type (e.g., 'arrow').
                """,
                temperature=0.5,
            )
        )
        reply_text = response.text
        print("Gemini Response:", reply_text)
        if not reply_text:
            # Report the missing text directly instead of as a parse failure.
            print("Error generating flowchart JSON: response contained no text")
            return {}
        return parse_json(reply_text)
    except Exception as e:
        print(f"Error generating flowchart JSON: {e}")
        return {}
# Function to predict the flowchart
def predict_flowchart(prompt):
    """
    Turn the user's prompt into a rendered flowchart image.

    The prompt is sent through generate_flowchart and the resulting JSON is
    drawn onto a fresh white 1000x800 canvas with draw_flowchart. If any step
    fails, a white canvas of the same size carrying the error message in red
    is returned instead, so the caller always receives an image.
    """
    canvas_size = (1000, 800)
    try:
        spec = generate_flowchart(prompt)
        if not spec:
            raise ValueError("Could not generate flowchart JSON.")
        # Start from an empty canvas and let the drawing routine fill it in
        canvas = Image.new("RGB", canvas_size, "white")
        return draw_flowchart(canvas, spec)
    except Exception as e:
        print(f"Error during processing: {e}")
        # Fall back to a canvas that shows what went wrong
        fallback = Image.new("RGB", canvas_size, "white")
        ImageDraw.Draw(fallback).text((50, 50), f"Error: {str(e)}", fill="red")
        return fallback
# Define the Gradio interface for flowcharts
def gradio_interface_flowcharts():
    """
    Build the Gradio Blocks app for flowchart generation.

    The layout is a single row with two columns: the first holds the prompt
    textbox and the generate button, the second holds the output image. The
    button's click event calls predict_flowchart with the prompt text and
    shows the returned image. The Blocks object is returned without being
    launched.
    """
    glass_theme = gr.themes.Glass(secondary_hue="blue")
    with gr.Blocks(glass_theme) as demo:
        gr.Markdown("# Flowchart Generator with Gemini")
        with gr.Row():
            # First column: user input
            with gr.Column():
                gr.Markdown("### Input Section")
                prompt_box = gr.Textbox(
                    lines=2,
                    label="Input Prompt",
                    placeholder="Describe the flowchart process.",
                )
                generate_btn = gr.Button("Generate Flowchart")
            # Second column: rendered result
            with gr.Column():
                gr.Markdown("### Output Section")
                chart_view = gr.Image(type="pil", label="Output Flowchart")
        # Wire the button to the flowchart pipeline
        generate_btn.click(predict_flowchart, inputs=[prompt_box], outputs=[chart_view])
    return demo
# Run the app: build the Blocks interface and launch it, but only when this
# file is executed as a script (nothing is launched on import).
if __name__ == "__main__":
    demo = gradio_interface_flowcharts()
    demo.launch()