File size: 6,268 Bytes
456e710 13b11e7 0fb620e 456e710 13b11e7 0fb620e 13b11e7 456e710 13b11e7 0fb620e 13b11e7 05c1957 ee17ed9 05c1957 ee17ed9 13b11e7 ee17ed9 13b11e7 456e710 ee17ed9 456e710 0fb620e 456e710 0fb620e ee17ed9 456e710 13b11e7 456e710 0fb620e 456e710 0fb620e 456e710 13b11e7 ee17ed9 456e710 13b11e7 05c1957 0fb620e 05c1957 0fb620e 13b11e7 0fb620e 13b11e7 0fb620e 13b11e7 ee17ed9 456e710 ee17ed9 456e710 ee17ed9 13b11e7 0fb620e ee17ed9 456e710 0fb620e 13b11e7 456e710 13b11e7 ee17ed9 456e710 13b11e7 456e710 13b11e7 0fb620e 13b11e7 0fb620e 13b11e7 ee17ed9 456e710 ee17ed9 456e710 13b11e7 456e710 0fb620e 456e710 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
import graphviz
import os
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from PIL import Image, ImageDraw, ImageFont
# --- 1. SETUP: LOAD THE MODEL LOCALLY ---
# The model is loaded into this process's memory once, when the module is first
# executed, instead of being called through the hosted Inference API.
print("--- Initializing Local Model ---")
# flan-t5-base is small enough to run on a free CPU Space and follows instructions well.
MODEL_ID = "google/flan-t5-base"
# Use a GPU when one is visible to torch. Free Spaces run on CPU, so there this
# resolves to "cpu" exactly as before.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- Using device: {DEVICE} ---")
# Load the tokenizer and the model itself.
# This might take a few minutes the first time the Space starts.
try:
    tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
    print(f"--- Model {MODEL_ID} loaded successfully ---")
except Exception as e:
    # Deliberately broad: this is the start-up boundary. Keep the app alive and let
    # generate_flowchart() show an error in the UI when it sees these are None.
    print(f"Error loading model: {e}")
    tokenizer, model = None, None
# --- 2. DEFINE THE PROMPT TEMPLATE ---
# A structured prompt is key to getting good results.
# `{user_prompt}` is substituted with str.format(), so any other literal brace added
# to this text must be doubled (`{{` / `}}`).
# Graphviz only paints a node's `fillcolor` when its style includes "filled", so each
# node shape below requests it explicitly.
SYSTEM_PROMPT_TEMPLATE = """Task: Generate a flowchart description in the Graphviz DOT language based on the following text.
Your response MUST be ONLY the Graphviz DOT language source code for a directed graph (digraph).
- The graph should be top-to-bottom (`rankdir=TB`).
- Use rounded boxes for process steps (`shape=box, style="rounded,filled", fillcolor="#EAEAFB"`).
- Use diamonds for decision points (`shape=diamond, style=filled, fillcolor="#F9EED5"`).
- Use ellipses for start and end nodes (`shape=ellipse, style=filled, fillcolor="#D5D6F9"`).
Text: "{user_prompt}"
DOT Language Code:"""
# --- 3. HELPER AND CORE FUNCTIONS ---
def _wrap_text_to_width(draw, text, font, max_width):
    """Greedily word-wrap *text* so each rendered line is at most *max_width* pixels wide.

    Existing newlines are kept as hard breaks (blank lines survive). A single word
    wider than *max_width* is placed on its own line rather than split.
    """
    wrapped_lines = []
    for paragraph in text.split("\n"):
        current = ""
        for word in paragraph.split():
            candidate = f"{current} {word}" if current else word
            # Measure single lines only; textlength() rejects multi-line strings.
            if current and draw.textlength(candidate, font=font) > max_width:
                wrapped_lines.append(current)
                current = word
            else:
                current = candidate
        wrapped_lines.append(current)
    return "\n".join(wrapped_lines)


def create_placeholder_image(text="Flowchart will be generated here", size=(600, 800), path="placeholder.png"):
    """Draw *text* centred on a blank white canvas and save it as an image file.

    Used for the initial placeholder as well as for status and error messages.

    Args:
        text: Message to draw. It may contain newlines; lines too wide for the
            canvas are word-wrapped.
        size: Canvas ``(width, height)`` in pixels.
        path: Destination file. It is overwritten on every call.

    Returns:
        ``path`` on success, or ``None`` if the image could not be produced.
    """
    margin = 20  # horizontal padding so wrapped text does not touch the canvas edges
    try:
        img = Image.new('RGB', size, color=(255, 255, 255))
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("DejaVuSans.ttf", 24)
        except OSError:
            # The TrueType font could not be opened; use Pillow's built-in font instead.
            font = ImageFont.load_default()
        # Long error messages would otherwise run off both sides of the canvas.
        text = _wrap_text_to_width(draw, text, font, max(size[0] - 2 * margin, 1))
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font, align="center")
        # Subtract the bbox origin so the rendered ink (not the anchor point) is centred.
        position = (
            (size[0] - (right - left)) / 2 - left,
            (size[1] - (bottom - top)) / 2 - top,
        )
        draw.text(position, text, fill=(200, 200, 200), font=font, align="center")
        img.save(path)
        return path
    except Exception as e:
        # Best-effort: a missing placeholder must not crash the app, but say why it failed.
        print(f"Could not create placeholder image: {e}")
        return None
def generate_flowchart(prompt: str):
    """Generate a flowchart image from a process description using the local model.

    The locally loaded model is asked for Graphviz DOT source. That source is
    cleaned up and rendered to PNG. No API token is needed.

    Args:
        prompt: The user's free-text process description.

    Returns:
        A 2-tuple ``(image_path, download_button_update)``.

        On success, the first item is the rendered PNG path. The second item shows
        the download button and points it at that file.

        On any failure, the first item is a placeholder image describing the
        problem, or ``None`` if even that could not be written. The second item
        hides the download button and clears its value.
    """
    # One consistent "no downloadable result" update for every failure path.
    hide_download = gr.update(value=None, visible=False)

    # model/tokenizer are set to None when loading failed at start-up.
    if model is None or tokenizer is None:
        return create_placeholder_image("Error: AI Model failed to load on startup. Please check the logs."), hide_download
    # Whitespace-only input is treated like empty input rather than sent to the model.
    if not prompt or not prompt.strip():
        return create_placeholder_image("Please enter a prompt to generate a flowchart."), hide_download

    try:
        # 1. Build the full prompt and tokenize it.
        full_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_prompt=prompt)
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(DEVICE)

        # 2. Generate with the local model. do_sample=True makes output vary between calls.
        outputs = model.generate(input_ids, max_new_tokens=1024, temperature=0.8, do_sample=True)
        dot_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        # 3. Strip Markdown code fences the model may wrap its answer in.
        if dot_code.startswith("```dot"):
            dot_code = dot_code[len("```dot"):].strip()
        if dot_code.startswith("```"):
            dot_code = dot_code[len("```"):].strip()
        if dot_code.endswith("```"):
            dot_code = dot_code[:-len("```")].strip()
        # Bare statements without a graph header are wrapped so Graphviz can parse them.
        if not dot_code.startswith(("digraph", "strict digraph")):
            dot_code = "digraph G {\n" + dot_code + "\n}"

        # 4. Render the DOT code to PNG with Graphviz.
        graph = graphviz.Source(dot_code)
        output_path = graph.render(os.path.join("outputs", "flowchart"), format='png', cleanup=True)
        return output_path, gr.update(value=output_path, visible=True)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the event handler.
        print(f"An error occurred during generation: {e}")
        error_message = f"An error occurred during generation.\nThe AI might have produced invalid flowchart code, or another issue occurred.\n\nDetails: {str(e)}"
        return create_placeholder_image(error_message), hide_download
# --- 4. GRADIO UI ---
# Page CSS: hide the Gradio footer and tint the app background light grey.
css = (
    "footer {display: none !important} "
    ".gradio-container {background-color: #f8f9fa}"
)

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# AI Flowchart Generator (Self-Contained)")
    gr.Markdown("This version runs a free, open-source model directly in this Space. No API keys or monthly limits!")

    with gr.Group():
        with gr.Row(equal_height=False):
            # Left column: the process description and the trigger button.
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    lines=10,
                    placeholder="e.g., Explain the process of making a cup of tea",
                    label="Enter your process description here",
                )
                generate_btn = gr.Button("✨ Generate", variant="primary")
            # Right column: the rendered chart, plus a download button hidden until needed.
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Flowchart",
                    type="filepath",
                    interactive=False,
                    value=create_placeholder_image(),
                    height=600,
                    show_label=False,
                )
                download_btn = gr.DownloadButton("⬇️ Download", variant="primary", visible=False)

    def on_generate_click(prompt):
        """Handle a click in two steps: first show a busy state, then show the result.

        Every yielded tuple maps onto ``(generate_btn, download_btn, output_image)``.
        """
        busy_image = create_placeholder_image("🧠 Thinking... Please wait.\n(First generation can be slow)")
        # Step 1: disable the button and hide any earlier download while the model runs.
        yield (gr.update(interactive=False), gr.update(visible=False), busy_image)
        chart_image, download_update = generate_flowchart(prompt)
        # Step 2: re-enable the button and show whatever generate_flowchart returned.
        yield (gr.update(interactive=True), download_update, chart_image)

    generate_btn.click(
        fn=on_generate_click,
        inputs=[prompt_input],
        outputs=[generate_btn, download_btn, output_image],
    )
if __name__ == "__main__":
    # generate_flowchart() renders into ./outputs. exist_ok avoids the
    # check-then-create race and is a no-op when the directory already exists.
    os.makedirs("outputs", exist_ok=True)
    demo.launch()