import gradio as gr
import graphviz
import os
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from PIL import Image, ImageDraw, ImageFont

# --- 1. SETUP: LOAD THE MODEL LOCALLY ---
# The model is loaded into this process's memory; no Inference API is used.
# This runs once, when the module is first executed at app start-up.
print("--- Initializing Local Model ---")

# Small enough to run on a free CPU Space, and good at following instructions.
MODEL_ID = "google/flan-t5-base"

# Prefer a GPU when one is present; free Spaces have none and fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- Using device: {DEVICE} ---")

# Load the tokenizer and the model weights.
# The first start-up may take a few minutes while the weights are fetched.
try:
    tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
    print(f"--- Model {MODEL_ID} loaded successfully ---")
except Exception as e:
    print(f"Error loading model: {e}")
    # Keep the app alive; generation checks for None and shows an error image.
    tokenizer, model = None, None

# --- 2. DEFINE THE PROMPT TEMPLATE ---
# A structured prompt is key to getting good results.
# `{user_prompt}` is substituted with str.format, so any literal brace added to
# this template must be doubled ("{{" / "}}").
# Graphviz only paints `fillcolor` when the node also has `style=filled`, which
# is why every node kind below requests it.
SYSTEM_PROMPT_TEMPLATE = """Task: Generate a flowchart description in the Graphviz DOT language based on the following text.
Your response MUST be ONLY the Graphviz DOT language source code for a directed graph (digraph).
- The graph should be top-to-bottom (`rankdir=TB`).
- Use rounded boxes for process steps (`shape=box, style="rounded,filled", fillcolor="#EAEAFB"`).
- Use diamonds for decision points (`shape=diamond, style=filled, fillcolor="#F9EED5"`).
- Use ellipses for start and end nodes (`shape=ellipse, style=filled, fillcolor="#D5D6F9"`).

Text: "{user_prompt}"

DOT Language Code:"""
# --- 3. HELPER AND CORE FUNCTIONS ---
def create_placeholder_image(text="Flowchart will be generated here", size=(600, 800), path="placeholder.png"):
    """Draw *text* centred on a blank white canvas and save it as a PNG.

    Used for the initial placeholder, the "thinking" message and error messages.

    Args:
        text: Message to draw. Newlines start new paragraphs; each paragraph is
            wrapped so long messages stay inside the canvas.
        size: ``(width, height)`` of the image in pixels.
        path: File the PNG is written to (overwritten on every call).

    Returns:
        ``path`` on success, or ``None`` if the image could not be created.
    """
    import textwrap  # local import keeps this helper self-contained

    try:
        img = Image.new('RGB', size, color=(255, 255, 255))
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("DejaVuSans.ttf", 24)
        except IOError:
            font = ImageFont.load_default()

        # Without wrapping, a long line (e.g. exception details) is wider than
        # the canvas and gets clipped off both edges once it is centred.
        wrapped = "\n".join(textwrap.fill(paragraph, width=38) for paragraph in text.split("\n"))

        bbox = draw.textbbox((0, 0), wrapped, font=font, align="center")
        text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        # Subtract the bbox origin so glyph bearings do not skew the centring.
        position = ((size[0] - text_width) / 2 - bbox[0], (size[1] - text_height) / 2 - bbox[1])
        draw.text(position, wrapped, fill=(200, 200, 200), font=font, align="center")
        img.save(path)
        return path
    except Exception as e:
        # Best effort: the UI can cope with no image, but record why it failed.
        print(f"Could not create placeholder image: {e}")
        return None


def _strip_code_fences(dot_code):
    """Remove a surrounding Markdown code fence (```dot, ```graphviz or ```) if present."""
    for fence in ("```dot", "```graphviz", "```"):
        if dot_code.startswith(fence):
            dot_code = dot_code[len(fence):].strip()
            break
    if dot_code.endswith("```"):
        dot_code = dot_code[:-len("```")].strip()
    return dot_code


def _ensure_graph_wrapper(dot_code):
    """Wrap bare DOT statements in ``digraph G { ... }`` unless a graph header is present."""
    # DOT keywords are case-insensitive and a graph may be declared `strict` or
    # undirected; wrapping such code again would produce a nested, invalid graph.
    if dot_code.lower().startswith(("digraph", "graph", "strict")):
        return dot_code
    return "digraph G {\n" + dot_code + "\n}"


def generate_flowchart(prompt: str):
    """Generate a flowchart image for *prompt* using the LOCALLY loaded model.

    No API token is needed.

    Args:
        prompt: Free-text description of the process to draw.

    Returns:
        A ``(image_path, download_update)`` pair. ``image_path`` is the rendered
        flowchart, or a placeholder image describing the problem (it may be
        ``None`` if even the placeholder could not be drawn).
        ``download_update`` is a ``gr.update`` for the download button. It is
        made visible only when a flowchart was actually rendered.
    """
    # Same update on every failure path so the download button is always hidden.
    hidden_download = gr.update(value=None, visible=False)

    # The model/tokenizer are None when loading failed at start-up.
    if model is None or tokenizer is None:
        return create_placeholder_image("Error: AI Model failed to load on startup. Please check the logs."), hidden_download
    if not prompt or not prompt.strip():
        return create_placeholder_image("Please enter a prompt to generate a flowchart."), hidden_download

    try:
        # 1. Build the full prompt and tokenize it.
        full_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_prompt=prompt)
        inputs = tokenizer(full_prompt, return_tensors="pt").input_ids.to(DEVICE)

        # 2. Generate with the local model (sampling, so output varies per run).
        outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.8, do_sample=True)
        dot_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        # 3. Clean up the generated DOT code.
        dot_code = _strip_code_fences(dot_code)
        if not dot_code:
            # Rendering an empty graph would show a blank "successful" result.
            return create_placeholder_image("The AI returned an empty response.\nPlease try again or rephrase your description."), hidden_download
        dot_code = _ensure_graph_wrapper(dot_code)

        # 4. Render the DOT code into a PNG with Graphviz.
        graph = graphviz.Source(dot_code)
        output_path = graph.render(os.path.join("outputs", "flowchart"), format='png', cleanup=True)
        return output_path, gr.update(value=output_path, visible=True)
    except Exception as e:
        print(f"An error occurred during generation: {e}")
        # Full details go to the log; keep the on-image version short enough to fit.
        details = str(e)
        if len(details) > 300:
            details = details[:300] + "..."
        error_message = f"An error occurred during generation.\nThe AI might have produced invalid flowchart code, or another issue occurred.\n\nDetails: {details}"
        return create_placeholder_image(error_message), hidden_download


# --- 4. GRADIO UI ---
css = "footer {display: none !important} .gradio-container {background-color: #f8f9fa}"

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# AI Flowchart Generator (Self-Contained)")
    gr.Markdown("This version runs a free, open-source model directly in this Space. No API keys or monthly limits!")
    with gr.Group():
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    lines=10,
                    placeholder="e.g., Explain the process of making a cup of tea",
                    label="Enter your process description here",
                )
                generate_btn = gr.Button("✨ Generate", variant="primary")
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Flowchart",
                    type="filepath",
                    interactive=False,
                    value=create_placeholder_image(),
                    height=600,
                    show_label=False,
                )
                download_btn = gr.DownloadButton("⬇️ Download", variant="primary", visible=False)

    def on_generate_click(prompt):
        """Stream UI updates: disable the button and show progress, then show the result.

        Each yielded tuple maps to ``(generate_btn, download_btn, output_image)``.
        """
        # Immediate feedback while the (possibly slow) generation runs.
        yield (gr.update(interactive=False), gr.update(visible=False), create_placeholder_image("🧠 Thinking... Please wait.\n(First generation can be slow)"))

        img_path, download_btn_update = generate_flowchart(prompt)

        # Re-enable the button and publish the result.
        yield (gr.update(interactive=True), download_btn_update, img_path)

    generate_btn.click(
        fn=on_generate_click,
        inputs=[prompt_input],
        outputs=[generate_btn, download_btn, output_image],
    )

if __name__ == "__main__":
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("outputs", exist_ok=True)
    demo.launch()