|
|
import gradio as gr |
|
|
import graphviz |
|
|
import os |
|
|
import torch |
|
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("--- Initializing Local Model ---")

# Hugging Face Hub id of the seq2seq model that writes the DOT code.
MODEL_ID = "google/flan-t5-base"

# Inference is pinned to the CPU.
DEVICE = "cpu"
print(f"--- Using device: {DEVICE} ---")

# Load once at import time. If anything fails, both globals end up as None so
# the UI can still start; generate_flowchart() checks for this and shows an
# error image instead of crashing.
tokenizer = model = None
try:
    tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
except Exception as e:
    print(f"Error loading model: {e}")
    # The tokenizer may already have loaded before the model failed; reset both.
    tokenizer = model = None
else:
    print(f"--- Model {MODEL_ID} loaded successfully ---")
|
|
|
|
|
|
|
|
|
|
|
# Instruction prompt sent to the model; ``{user_prompt}`` is filled in with
# str.format(), so the template itself must contain no other braces.
# NOTE: Graphviz only paints ``fillcolor`` on filled nodes, so every coloured
# node shape below must also carry ``style=filled`` (or a style containing it).
SYSTEM_PROMPT_TEMPLATE = """Task: Generate a flowchart description in the Graphviz DOT language based on the following text.
Your response MUST be ONLY the Graphviz DOT language source code for a directed graph (digraph).
- The graph should be top-to-bottom (`rankdir=TB`).
- Use rounded boxes for process steps (`shape=box, style="rounded,filled", fillcolor="#EAEAFB"`).
- Use diamonds for decision points (`shape=diamond, style=filled, fillcolor="#F9EED5"`).
- Use ellipses for start and end nodes (`shape=ellipse, style=filled, fillcolor="#D5D6F9"`).

Text: "{user_prompt}"

DOT Language Code:"""
|
|
|
|
|
|
|
|
|
|
|
def _wrap_text_to_width(text, measure, max_width):
    """Greedily word-wrap *text* so each line fits within *max_width*.

    Existing newlines are preserved. ``measure`` maps a single line of text to
    its rendered width (same unit as ``max_width``). A single word that is
    wider than ``max_width`` is kept on its own line rather than split.
    """
    wrapped_lines = []
    for paragraph in text.split("\n"):
        current = ""
        for word in paragraph.split(" "):
            candidate = f"{current} {word}" if current else word
            if current and measure(candidate) > max_width:
                # Adding this word would overflow: close the line, start a new one.
                wrapped_lines.append(current)
                current = word
            else:
                current = candidate
        wrapped_lines.append(current)
    return "\n".join(wrapped_lines)


def create_placeholder_image(text="Flowchart will be generated here", size=(600, 800), path="placeholder.png"):
    """Draw *text* centred on a blank white image and save it to *path*.

    Used for the initial placeholder as well as for status and error
    messages. Lines wider than the image (e.g. long exception details) are
    word-wrapped so they are not clipped at the image edges.

    Args:
        text: Message to draw; may contain newlines.
        size: ``(width, height)`` of the image in pixels.
        path: File the image is written to (overwritten on every call).

    Returns:
        ``path`` on success, or ``None`` if the image could not be created.
    """
    margin = 20  # horizontal padding in pixels kept free on each side
    try:
        img = Image.new('RGB', size, color=(255, 255, 255))
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("DejaVuSans.ttf", 24)
        except OSError:
            # Font file not available: fall back to Pillow's built-in font.
            font = ImageFont.load_default()

        def measure(line):
            left, _, right, _ = draw.textbbox((0, 0), line, font=font)
            return right - left

        wrapped = _wrap_text_to_width(text, measure, size[0] - 2 * margin)

        left, top, right, bottom = draw.textbbox((0, 0), wrapped, font=font, align="center")
        text_width, text_height = right - left, bottom - top
        # Subtract the bbox origin offset so the visible text, not its anchor
        # point, ends up in the middle of the image.
        position = ((size[0] - text_width) / 2 - left, (size[1] - text_height) / 2 - top)
        draw.text(position, wrapped, fill=(200, 200, 200), font=font, align="center")

        img.save(path)
        return path
    except Exception as e:
        # Best effort: callers treat None as "no image", but record why.
        print(f"Could not create placeholder image: {e}")
        return None
|
|
|
|
|
def _strip_dot_fences(raw):
    """Turn raw model output into standalone Graphviz DOT source.

    Removes a leading Markdown code fence (```dot, ```graphviz or a bare ```)
    and a trailing ```, then wraps the result in ``digraph G { ... }`` when it
    does not already start with a ``digraph`` / ``strict digraph`` header.
    """
    code = raw.strip()
    # Most specific fence first: "```dot" and "```graphviz" also start with "```".
    for fence in ("```dot", "```graphviz", "```"):
        if code.startswith(fence):
            code = code[len(fence):].strip()
            break
    if code.endswith("```"):
        code = code[:-len("```")].strip()
    if not code.startswith(("digraph", "strict digraph")):
        code = "digraph G {\n" + code + "\n}"
    return code


def generate_flowchart(prompt: str):
    """Generate a flowchart image from a natural-language description.

    Uses the LOCALLY loaded model (no API token is needed) to write Graphviz
    DOT source. The source is cleaned up and rendered to
    ``outputs/flowchart.png``.

    Args:
        prompt: Free-text description of the process to visualise.

    Returns:
        A 2-tuple ``(image_path, download_update)``.
        ``image_path`` is the rendered flowchart, or a placeholder image
        carrying a hint/error message. It may be ``None`` if even the
        placeholder could not be drawn.
        ``download_update`` is a ``gr.update`` for the download button. It is
        visible only when a flowchart was rendered.
    """
    hide_download = gr.update(visible=False)

    if not model or not tokenizer:
        return create_placeholder_image("Error: AI Model failed to load on startup. Please check the logs."), hide_download

    # A whitespace-only prompt carries no description either.
    if not prompt or not prompt.strip():
        return create_placeholder_image("Please enter a prompt to generate a flowchart."), hide_download

    try:
        full_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_prompt=prompt)
        inputs = tokenizer(full_prompt, return_tensors="pt").input_ids.to(DEVICE)

        # do_sample=True: decoding is random, so repeated clicks on the same
        # prompt can yield different diagrams.
        outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.8, do_sample=True)
        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        dot_code = _strip_dot_fences(raw_output)

        graph = graphviz.Source(dot_code)
        # cleanup=True removes the intermediate DOT source file, keeping only the PNG.
        output_path = graph.render(os.path.join("outputs", "flowchart"), format='png', cleanup=True)

        return output_path, gr.update(value=output_path, visible=True)

    except Exception as e:
        # UI boundary: report any failure (bad DOT, Graphviz missing, ...) as an image.
        print(f"An error occurred during generation: {e}")
        error_message = f"An error occurred during generation.\nThe AI might have produced invalid flowchart code, or another issue occurred.\n\nDetails: {str(e)}"
        return create_placeholder_image(error_message), hide_download
|
|
|
|
|
|
|
|
# Hide Gradio's footer and give the page a light grey background.
css = "footer {display: none !important} .gradio-container {background-color: #f8f9fa}"

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# AI Flowchart Generator (Self-Contained)")
    gr.Markdown("This version runs a free, open-source model directly in this Space. No API keys or monthly limits!")

    with gr.Group():
        with gr.Row(equal_height=False):
            # Left column: the process description and the trigger button.
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    lines=10,
                    placeholder="e.g., Explain the process of making a cup of tea",
                    label="Enter your process description here",
                )
                generate_btn = gr.Button("✨ Generate", variant="primary")

            # Right column: the result image, plus a download button that
            # starts hidden and is only revealed after a successful render.
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Flowchart",
                    type="filepath",
                    interactive=False,
                    value=create_placeholder_image(),
                    height=600,
                    show_label=False,
                )
                download_btn = gr.DownloadButton("⬇️ Download", variant="primary", visible=False)

    def on_generate_click(prompt):
        """Stream two UI states: a busy state while generating, then the result.

        Each yield is ordered to match ``outputs``:
        (generate button, download button, output image).
        """
        busy_image = create_placeholder_image("🧠 Thinking... Please wait.\n(First generation can be slow)")
        # Lock the button and hide any stale download while the model runs.
        yield gr.update(interactive=False), gr.update(visible=False), busy_image

        result_image, download_update = generate_flowchart(prompt)
        yield gr.update(interactive=True), download_update, result_image

    generate_btn.click(
        on_generate_click,
        inputs=[prompt_input],
        outputs=[generate_btn, download_btn, output_image],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Directory that generate_flowchart() renders into. exist_ok=True avoids
    # the check-then-create race of a separate os.path.exists() test.
    os.makedirs("outputs", exist_ok=True)
    demo.launch()