File size: 6,268 Bytes
456e710
 
 
13b11e7
0fb620e
456e710
 
13b11e7
 
 
 
0fb620e
13b11e7
 
456e710
13b11e7
 
0fb620e
 
13b11e7
 
 
 
 
 
 
 
 
 
 
 
 
05c1957
ee17ed9
 
 
 
 
05c1957
 
 
 
ee17ed9
13b11e7
 
ee17ed9
13b11e7
456e710
ee17ed9
456e710
0fb620e
 
456e710
0fb620e
ee17ed9
456e710
 
 
13b11e7
 
456e710
0fb620e
456e710
0fb620e
456e710
13b11e7
 
 
 
ee17ed9
 
456e710
 
13b11e7
05c1957
0fb620e
05c1957
0fb620e
 
 
 
13b11e7
0fb620e
 
 
13b11e7
 
0fb620e
13b11e7
ee17ed9
456e710
 
ee17ed9
456e710
ee17ed9
13b11e7
0fb620e
ee17ed9
456e710
0fb620e
13b11e7
456e710
13b11e7
 
 
ee17ed9
456e710
 
13b11e7
 
456e710
13b11e7
 
0fb620e
13b11e7
 
 
 
0fb620e
13b11e7
 
ee17ed9
456e710
ee17ed9
456e710
13b11e7
456e710
 
 
0fb620e
456e710
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
import graphviz
import os
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from PIL import Image, ImageDraw, ImageFont

# --- 1. SETUP: LOAD THE MODEL LOCALLY ---
# The model is loaded into this process's memory instead of calling the hosted
# Inference API. This runs once, at import time, when the app starts up.

print("--- Initializing Local Model ---")
# Checkpoint to load; chosen because it is small enough to run on a free CPU Space.
MODEL_ID = "google/flan-t5-base"

# Pick the device at startup instead of hard-coding it: use a GPU when torch
# reports one, otherwise fall back to CPU (which is what free Spaces provide).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- Using device: {DEVICE} ---")

# Load the tokenizer and the model weights.
# This might take a few minutes the first time the Space starts.
try:
    tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
    print(f"--- Model {MODEL_ID} loaded successfully ---")
except Exception as e:
    # Startup boundary: keep the module importable so the UI can still come up.
    # Both names are set to None so later code can detect the failed load.
    print(f"Error loading model: {e}")
    tokenizer, model = None, None

# --- 2. DEFINE THE PROMPT TEMPLATE ---
# Instruction prompt sent to the model, assembled one line per tuple entry.
# `{user_prompt}` is the only placeholder; it is filled in via str.format().
SYSTEM_PROMPT_TEMPLATE = "\n".join((
    'Task: Generate a flowchart description in the Graphviz DOT language based on the following text.',
    'Your response MUST be ONLY the Graphviz DOT language source code for a directed graph (digraph).',
    '- The graph should be top-to-bottom (`rankdir=TB`).',
    '- Use rounded boxes for process steps (`shape=box, style="rounded,filled", fillcolor="#EAEAFB"`).',
    '- Use diamonds for decision points (`shape=diamond, fillcolor="#F9EED5"`).',
    '- Use ellipses for start and end nodes (`shape=ellipse, fillcolor="#D5D6F9"`).',
    '',
    'Text: "{user_prompt}"',
    '',
    'DOT Language Code:',
))


# --- 3. HELPER AND CORE FUNCTIONS ---
def _wrap_text_to_width(draw, text, font, max_width):
    """Re-break *text* so every rendered line fits within *max_width* pixels.

    Existing newlines are kept as hard breaks, and lines that already fit are
    left untouched. A single word wider than *max_width* stays on its own line
    rather than being split.
    """
    wrapped_lines = []
    for paragraph in text.split("\n"):
        # Fast path: keep lines that already fit exactly as written.
        if draw.textlength(paragraph, font=font) <= max_width:
            wrapped_lines.append(paragraph)
            continue
        current = ""
        for word in paragraph.split(" "):
            candidate = f"{current} {word}" if current else word
            if current and draw.textlength(candidate, font=font) > max_width:
                wrapped_lines.append(current)
                current = word
            else:
                current = candidate
        wrapped_lines.append(current)
    return "\n".join(wrapped_lines)


def create_placeholder_image(text="Flowchart will be generated here", size=(600, 800), path="placeholder.png"):
    """Create a white placeholder/error image with light-grey centered text.

    Args:
        text: Message to draw. May contain newlines; lines wider than the
            image are word-wrapped so long error details are not clipped.
        size: ``(width, height)`` of the image in pixels.
        path: File path the image is saved to (overwritten on every call).

    Returns:
        ``path`` on success, or ``None`` if the image could not be created
        or saved (the failure is logged, not raised).
    """
    margin = 20  # horizontal padding, in pixels, kept free on each side of the text
    try:
        img = Image.new('RGB', size, color=(255, 255, 255))
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("DejaVuSans.ttf", 24)
        except OSError:
            # Font file not available on this host: use Pillow's built-in font.
            font = ImageFont.load_default()
        text = _wrap_text_to_width(draw, text, font, max(size[0] - 2 * margin, 1))
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        position = ((size[0] - text_width) / 2, (size[1] - text_height) / 2)
        draw.text(position, text, fill=(200, 200, 200), font=font)
        img.save(path)
        return path
    except Exception as e:
        # Best-effort helper: callers treat None as "no image", but leave a
        # trace in the logs instead of hiding the failure entirely.
        print(f"Could not create placeholder image: {e}")
        return None

def _extract_dot_source(raw_output):
    """Strip Markdown code fences from model output and ensure a digraph wrapper."""
    dot_code = raw_output.strip()
    # Remove an opening fence, with or without the "dot" language tag.
    for fence in ("```dot", "```"):
        if dot_code.startswith(fence):
            dot_code = dot_code[len(fence):].strip()
    if dot_code.endswith("```"):
        dot_code = dot_code[:-len("```")].strip()
    # Bare statements get wrapped; an existing (strict) digraph is left as is.
    if not dot_code.startswith(("digraph", "strict digraph")):
        dot_code = "digraph G {\n" + dot_code + "\n}"
    return dot_code


def generate_flowchart(prompt: str):
    """Generate a flowchart image from *prompt* using the locally loaded model.

    No API token is needed. The prompt is inserted into
    ``SYSTEM_PROMPT_TEMPLATE``, the model's output is cleaned up into DOT
    source, and Graphviz renders it to a PNG under ``outputs/``.

    Args:
        prompt: Free-text description of the process to chart.

    Returns:
        A ``(image_path, download_button_update)`` tuple. On success
        ``image_path`` is the rendered PNG and the update makes the download
        button visible with that file. On any failure (model not loaded,
        empty prompt, generation/rendering error) ``image_path`` is a
        placeholder image carrying the message and the update hides the button.
    """
    # The startup code sets these to None when loading fails. Compare with
    # `is None` rather than truthiness: the tokenizer defines __len__.
    if model is None or tokenizer is None:
        return create_placeholder_image("Error: AI Model failed to load on startup. Please check the logs."), gr.update(visible=False)

    # Reject missing and whitespace-only prompts before spending time on generation.
    if not prompt or not prompt.strip():
        return create_placeholder_image("Please enter a prompt to generate a flowchart."), gr.update(visible=False)

    try:
        # 1. Prepare the full prompt and convert it to tokens
        full_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_prompt=prompt.strip())
        inputs = tokenizer(full_prompt, return_tensors="pt").input_ids.to(DEVICE)

        # 2. Generate the output from the local model (sampling, so results vary per call)
        outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.8, do_sample=True)
        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # 3. Clean up the generated DOT code
        dot_code = _extract_dot_source(raw_output)

        # 4. Render the DOT code into an image using Graphviz
        graph = graphviz.Source(dot_code)
        output_path = graph.render(os.path.join("outputs", "flowchart"), format='png', cleanup=True)

        return output_path, gr.update(value=output_path, visible=True)

    except Exception as e:
        # Request boundary: any tokenization, generation or rendering error is
        # turned into an on-screen message instead of propagating to the UI.
        print(f"An error occurred during generation: {e}")
        error_message = f"An error occurred during generation.\nThe AI might have produced invalid flowchart code, or another issue occurred.\n\nDetails: {str(e)}"
        return create_placeholder_image(error_message), gr.update(visible=False)

# --- 4. GRADIO UI ---
# Hide the default footer and give the page a light grey background.
css = (
    "footer {display: none !important} "
    ".gradio-container {background-color: #f8f9fa}"
)
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # Page title and a one-line description.
    gr.Markdown("# AI Flowchart Generator (Self-Contained)")
    gr.Markdown("This version runs a free, open-source model directly in this Space. No API keys or monthly limits!")

    # One row with two equally scaled columns: prompt entry first, result second.
    with gr.Group():
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    lines=10,
                    placeholder="e.g., Explain the process of making a cup of tea",
                    label="Enter your process description here",
                )
                generate_btn = gr.Button("✨ Generate", variant="primary")
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Flowchart",
                    type="filepath",
                    interactive=False,
                    value=create_placeholder_image(),
                    height=600,
                    show_label=False,
                )
                # Starts hidden; the click handler's outputs decide when it is shown.
                download_btn = gr.DownloadButton("⬇️ Download", variant="primary", visible=False)

    def on_generate_click(prompt):
        """Run one generation, yielding a 'busy' UI state and then the result.

        Each yielded tuple maps, in order, to (generate_btn, download_btn,
        output_image).
        """
        # First update: disable the button, hide the download, show a waiting image.
        lock_button = gr.update(interactive=False)
        hide_download = gr.update(visible=False)
        waiting_image = create_placeholder_image("🧠 Thinking... Please wait.\n(First generation can be slow)")
        yield (lock_button, hide_download, waiting_image)

        # Do the actual work; no token is required for the local model.
        result_image, download_state = generate_flowchart(prompt)

        # Second update: re-enable the button and show the outcome.
        unlock_button = gr.update(interactive=True)
        yield (unlock_button, download_state, result_image)

    generate_btn.click(
        fn=on_generate_click,
        inputs=[prompt_input],
        outputs=[generate_btn, download_btn, output_image],
    )

if __name__ == "__main__":
    # Make sure the directory the flowcharts are rendered into exists.
    # exist_ok=True replaces the separate os.path.exists() check, which could
    # race with another process creating the directory in between.
    os.makedirs("outputs", exist_ok=True)
    demo.launch()