# app.py — AI Flowchart Generator (Hugging Face Space, runs the model locally)
import gradio as gr
import graphviz
import os
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from PIL import Image, ImageDraw, ImageFont
# --- 1. SETUP: LOAD THE MODEL LOCALLY ---
# The model runs inside the Space itself (no Inference API); it is loaded a
# single time when the app process starts.
print("--- Initializing Local Model ---")

# Small instruction-tuned model that fits comfortably on a free CPU Space.
MODEL_ID = "google/flan-t5-base"

# Free Spaces only provide CPU hardware.
DEVICE = "cpu"
print(f"--- Using device: {DEVICE} ---")

# Downloading the weights can take a few minutes on the very first start.
try:
    tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
except Exception as e:
    # Keep the app alive; the UI surfaces the failure when generation is tried.
    print(f"Error loading model: {e}")
    tokenizer, model = None, None
else:
    print(f"--- Model {MODEL_ID} loaded successfully ---")
# --- 2. DEFINE THE PROMPT TEMPLATE ---
# A structured prompt is key to getting good results.
# The single `{user_prompt}` placeholder is filled via .format() in
# generate_flowchart(); the template text itself is part of the model input,
# so its exact wording matters.
SYSTEM_PROMPT_TEMPLATE = """Task: Generate a flowchart description in the Graphviz DOT language based on the following text.
Your response MUST be ONLY the Graphviz DOT language source code for a directed graph (digraph).
- The graph should be top-to-bottom (`rankdir=TB`).
- Use rounded boxes for process steps (`shape=box, style="rounded,filled", fillcolor="#EAEAFB"`).
- Use diamonds for decision points (`shape=diamond, fillcolor="#F9EED5"`).
- Use ellipses for start and end nodes (`shape=ellipse, fillcolor="#D5D6F9"`).
Text: "{user_prompt}"
DOT Language Code:"""
# --- 3. HELPER AND CORE FUNCTIONS ---
def create_placeholder_image(text="Flowchart will be generated here", size=(600, 800), path="placeholder.png"):
    """Render *text* centered on a white canvas and save it to *path*.

    Returns the saved file path, or None if the image could not be created
    (best-effort: a missing placeholder must never crash the UI).
    """
    try:
        canvas = Image.new('RGB', size, color=(255, 255, 255))
        pen = ImageDraw.Draw(canvas)
        try:
            font = ImageFont.truetype("DejaVuSans.ttf", 24)
        except IOError:
            # Bundled font not found; fall back to PIL's built-in bitmap font.
            font = ImageFont.load_default()
        left, top, right, bottom = pen.textbbox((0, 0), text, font=font)
        width, height = right - left, bottom - top
        # Center the text block on the canvas.
        origin = ((size[0] - width) / 2, (size[1] - height) / 2)
        pen.text(origin, text, fill=(200, 200, 200), font=font)
        canvas.save(path)
    except Exception:
        return None
    return path
def generate_flowchart(prompt: str):
    """Generate a flowchart PNG from a plain-text process description.

    Uses the locally loaded FLAN-T5 model to produce Graphviz DOT code,
    renders it with Graphviz, and returns a tuple of
    (image path, gr.update for the download button). On any failure a
    placeholder/error image is returned instead of raising.
    """
    # Model may have failed to load at startup (see setup section).
    if not model or not tokenizer:
        return create_placeholder_image("Error: AI Model failed to load on startup. Please check the logs."), None
    # Fix: also reject whitespace-only prompts, which previously slipped
    # through the `if not prompt` check and produced a garbage flowchart.
    if not prompt or not prompt.strip():
        return create_placeholder_image("Please enter a prompt to generate a flowchart."), None
    try:
        # 1. Build the full prompt and tokenize it.
        full_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_prompt=prompt)
        inputs = tokenizer(full_prompt, return_tensors="pt").input_ids.to(DEVICE)
        # 2. Generate DOT code with the local model (sampled, so non-deterministic).
        outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.8, do_sample=True)
        dot_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        # 3. Strip any markdown code fences the model may have wrapped around the code.
        if dot_code.startswith("```dot"):
            dot_code = dot_code[len("```dot"):].strip()
        if dot_code.startswith("```"):
            dot_code = dot_code[len("```"):].strip()
        if dot_code.endswith("```"):
            dot_code = dot_code[:-len("```")].strip()
        # Wrap bare statements in a digraph so Graphviz can parse them.
        if not dot_code.startswith("digraph"):
            dot_code = "digraph G {\n" + dot_code + "\n}"
        # 4. Render the DOT code into a PNG.
        # Fix: ensure the output directory exists here — previously it was only
        # created under the __main__ guard, so rendering failed when the module
        # was imported by another entry point.
        os.makedirs("outputs", exist_ok=True)
        graph = graphviz.Source(dot_code)
        output_path = graph.render(os.path.join("outputs", "flowchart"), format='png', cleanup=True)
        return output_path, gr.update(value=output_path, visible=True)
    except Exception as e:
        # Graphviz raises on invalid DOT; the model produces it regularly.
        print(f"An error occurred during generation: {e}")
        error_message = f"An error occurred during generation.\nThe AI might have produced invalid flowchart code, or another issue occurred.\n\nDetails: {str(e)}"
        return create_placeholder_image(error_message), gr.update(visible=False)
# --- 4. GRADIO UI ---
# Hide the Gradio footer and lighten the page background.
css = "footer {display: none !important} .gradio-container {background-color: #f8f9fa}"

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# AI Flowchart Generator (Self-Contained)")
    gr.Markdown("This version runs a free, open-source model directly in this Space. No API keys or monthly limits!")

    with gr.Group():
        with gr.Row(equal_height=False):
            # Left column: prompt entry and trigger button.
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    lines=10,
                    placeholder="e.g., Explain the process of making a cup of tea",
                    label="Enter your process description here",
                )
                generate_btn = gr.Button("✨ Generate", variant="primary")
            # Right column: rendered flowchart and download control.
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Flowchart",
                    type="filepath",
                    interactive=False,
                    value=create_placeholder_image(),
                    height=600,
                    show_label=False,
                )
                download_btn = gr.DownloadButton("⬇️ Download", variant="primary", visible=False)

    def on_generate_click(prompt):
        """Generator callback: first yield shows progress, second shows the result."""
        # Disable the button and show a "thinking" placeholder while we work.
        yield (
            gr.update(interactive=False),
            gr.update(visible=False),
            create_placeholder_image("🧠 Thinking... Please wait.\n(First generation can be slow)"),
        )
        # Generation runs entirely locally — no API token required.
        result_path, download_update = generate_flowchart(prompt)
        # Re-enable the button and publish the result.
        yield (gr.update(interactive=True), download_update, result_path)

    generate_btn.click(
        fn=on_generate_click,
        inputs=[prompt_input],
        outputs=[generate_btn, download_btn, output_image],
    )
if __name__ == "__main__":
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("outputs", exist_ok=True)
    demo.launch()