Diggz10 committed on
Commit
13b11e7
·
verified ·
1 Parent(s): 0fb620e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -55
app.py CHANGED
@@ -1,27 +1,35 @@
1
  import gradio as gr
2
  import graphviz
3
  import os
 
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
  from PIL import Image, ImageDraw, ImageFont
6
- import torch
7
 
8
- # --- 1. MODEL LOADING (LOCALLY INSIDE THE SPACE) ---
9
- # No more Inference API! We are loading the model directly.
 
 
10
  print("--- Initializing Local Model ---")
11
- MODEL_ID = "google/flan-t5-base" # A powerful model small enough to run on a free CPU Space
 
12
 
13
- # Check for device
14
- DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
15
  print(f"--- Using device: {DEVICE} ---")
16
 
17
- # Load the tokenizer and model from the Hub into the Space's memory
18
- tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
19
- model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
20
- print(f"--- Model {MODEL_ID} Initialized Successfully ---")
21
-
22
-
23
- # --- 2. SETUP ---
24
- # The prompt template for our instruction-tuned model
 
 
 
 
 
25
  SYSTEM_PROMPT_TEMPLATE = """Task: Generate a flowchart description in the Graphviz DOT language based on the following text.
26
  Your response MUST be ONLY the Graphviz DOT language source code for a directed graph (digraph).
27
  - The graph should be top-to-bottom (`rankdir=TB`).
@@ -33,9 +41,10 @@ Text: "{user_prompt}"
33
 
34
  DOT Language Code:"""
35
 
36
- # --- Helper function for placeholder images
 
37
  def create_placeholder_image(text="Flowchart will be generated here", size=(600, 800), path="placeholder.png"):
38
- # (This function remains unchanged)
39
  try:
40
  img = Image.new('RGB', size, color=(255, 255, 255))
41
  draw = ImageDraw.Draw(img)
@@ -47,19 +56,22 @@ def create_placeholder_image(text="Flowchart will be generated here", size=(600,
47
  draw.text(position, text, fill=(200, 200, 200), font=font)
48
  img.save(path)
49
  return path
50
- except Exception: return None
 
51
 
52
-
53
- # --- 3. CORE AI AND RENDERING LOGIC ---
54
  def generate_flowchart(prompt: str):
55
  """
56
  Generates a flowchart using the LOCALLY loaded model. No API token is needed.
57
  """
 
 
 
 
58
  if not prompt:
59
  return create_placeholder_image("Please enter a prompt to generate a flowchart."), None
60
 
61
  try:
62
- # 1. Prepare the full prompt and tokenize it
63
  full_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_prompt=prompt)
64
  inputs = tokenizer(full_prompt, return_tensors="pt").input_ids.to(DEVICE)
65
 
@@ -67,66 +79,51 @@ def generate_flowchart(prompt: str):
67
  outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.8, do_sample=True)
68
  dot_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
69
 
70
- # 3. Clean up the generated code
71
  if dot_code.startswith("```dot"): dot_code = dot_code[len("```dot"):].strip()
72
  if dot_code.startswith("```"): dot_code = dot_code[len("```"):].strip()
73
  if dot_code.endswith("```"): dot_code = dot_code[:-len("```")].strip()
 
 
74
 
75
- # 4. Render the DOT code using Graphviz
76
  graph = graphviz.Source(dot_code)
77
  output_path = graph.render(os.path.join("outputs", "flowchart"), format='png', cleanup=True)
78
 
79
  return output_path, gr.update(value=output_path, visible=True)
80
 
81
  except Exception as e:
82
- print(f"An error occurred: {e}")
83
  error_message = f"An error occurred during generation.\nThe AI might have produced invalid flowchart code, or another issue occurred.\n\nDetails: {str(e)}"
84
  return create_placeholder_image(error_message), gr.update(visible=False)
85
 
86
-
87
  # --- 4. GRADIO UI ---
88
- # (The Gradio UI block remains mostly unchanged, just removing the token logic)
89
- css = """
90
- footer {display: none !important}
91
- .gradio-container {background-color: #f8f9fa}
92
- #status_display {text-align: center; color: #888;}
93
- """
94
-
95
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
96
- gr.Markdown("# AI Flowchart Generator")
97
- gr.Markdown(
98
- "Our AI Flowchart Generator allows you to create detailed flowcharts instantly. This version runs a self-contained model directly in this Space."
99
- )
100
  with gr.Group():
101
  with gr.Row(equal_height=False):
102
  with gr.Column(scale=1):
103
- prompt_input = gr.Textbox(
104
- lines=10,
105
- placeholder="e.g., Explain the process of making a cup of tea",
106
- label="Enter your process description here"
107
- )
108
- with gr.Row():
109
- generate_btn = gr.Button("✨ Generate", variant="primary")
110
- status_display = gr.Markdown("", elem_id="status_display")
111
  with gr.Column(scale=1):
112
- output_image = gr.Image(
113
- label="Generated Flowchart", type="filepath", interactive=False,
114
- value=create_placeholder_image(), height=600, show_label=False
115
- )
116
- download_btn = gr.DownloadButton(
117
- "⬇️ Download", variant="primary", visible=False,
118
- )
119
 
120
- def on_generate_click(prompt, progress=gr.Progress(track_tqdm=True)):
121
- yield (gr.update(interactive=False), gr.update(visible=False), create_placeholder_image("🧠 Thinking... Please wait."), "Generating...")
122
- # Note: The 'hf_token' is no longer passed here
 
123
  img_path, download_btn_update = generate_flowchart(prompt)
124
- yield (gr.update(interactive=True), download_btn_update, img_path, "")
 
125
 
126
  generate_btn.click(
127
  fn=on_generate_click,
128
  inputs=[prompt_input],
129
- outputs=[generate_btn, download_btn, output_image, status_display]
130
  )
131
 
132
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import graphviz
3
  import os
4
+ import torch
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
  from PIL import Image, ImageDraw, ImageFont
 
7
 
8
+ # --- 1. SETUP: LOAD THE MODEL LOCALLY ---
9
+ # We are no longer using the Inference API. This loads the model into the Space's memory.
10
+ # This happens only once when the app starts up.
11
+
12
  print("--- Initializing Local Model ---")
13
+ # This model is small enough to run on a free CPU Space and is excellent at following instructions.
14
+ MODEL_ID = "google/flan-t5-base"
15
 
16
+ # Determine the device. Free Spaces run on CPU.
17
+ DEVICE = "cpu"
18
  print(f"--- Using device: {DEVICE} ---")
19
 
20
+ # Load the model's tokenizer and the model itself.
21
+ # This might take a few minutes the first time the Space starts.
22
+ try:
23
+ tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
24
+ model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
25
+ print(f"--- Model {MODEL_ID} loaded successfully ---")
26
+ except Exception as e:
27
+ print(f"Error loading model: {e}")
28
+ # Handle model loading failure gracefully in the UI later
29
+ tokenizer, model = None, None
30
+
31
+ # --- 2. DEFINE THE PROMPT TEMPLATE ---
32
+ # A structured prompt is key to getting good results.
33
  SYSTEM_PROMPT_TEMPLATE = """Task: Generate a flowchart description in the Graphviz DOT language based on the following text.
34
  Your response MUST be ONLY the Graphviz DOT language source code for a directed graph (digraph).
35
  - The graph should be top-to-bottom (`rankdir=TB`).
 
41
 
42
  DOT Language Code:"""
43
 
44
+
45
+ # --- 3. HELPER AND CORE FUNCTIONS ---
46
  def create_placeholder_image(text="Flowchart will be generated here", size=(600, 800), path="placeholder.png"):
47
+ """Creates a placeholder or error image with text."""
48
  try:
49
  img = Image.new('RGB', size, color=(255, 255, 255))
50
  draw = ImageDraw.Draw(img)
 
56
  draw.text(position, text, fill=(200, 200, 200), font=font)
57
  img.save(path)
58
  return path
59
+ except Exception:
60
+ return None
61
 
 
 
62
  def generate_flowchart(prompt: str):
63
  """
64
  Generates a flowchart using the LOCALLY loaded model. No API token is needed.
65
  """
66
+ # Check if the model failed to load on startup
67
+ if not model or not tokenizer:
68
+ return create_placeholder_image("Error: AI Model failed to load on startup. Please check the logs."), None
69
+
70
  if not prompt:
71
  return create_placeholder_image("Please enter a prompt to generate a flowchart."), None
72
 
73
  try:
74
+ # 1. Prepare the full prompt and convert it to tokens
75
  full_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_prompt=prompt)
76
  inputs = tokenizer(full_prompt, return_tensors="pt").input_ids.to(DEVICE)
77
 
 
79
  outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.8, do_sample=True)
80
  dot_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
81
 
82
+ # 3. Clean up the generated DOT code
83
  if dot_code.startswith("```dot"): dot_code = dot_code[len("```dot"):].strip()
84
  if dot_code.startswith("```"): dot_code = dot_code[len("```"):].strip()
85
  if dot_code.endswith("```"): dot_code = dot_code[:-len("```")].strip()
86
+ if not dot_code.startswith("digraph"): dot_code = "digraph G {\n" + dot_code + "\n}"
87
+
88
 
89
+ # 4. Render the DOT code into an image using Graphviz
90
  graph = graphviz.Source(dot_code)
91
  output_path = graph.render(os.path.join("outputs", "flowchart"), format='png', cleanup=True)
92
 
93
  return output_path, gr.update(value=output_path, visible=True)
94
 
95
  except Exception as e:
96
+ print(f"An error occurred during generation: {e}")
97
  error_message = f"An error occurred during generation.\nThe AI might have produced invalid flowchart code, or another issue occurred.\n\nDetails: {str(e)}"
98
  return create_placeholder_image(error_message), gr.update(visible=False)
99
 
 
100
  # --- 4. GRADIO UI ---
101
+ css = "footer {display: none !important} .gradio-container {background-color: #f8f9fa}"
 
 
 
 
 
 
102
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
103
+ gr.Markdown("# AI Flowchart Generator (Self-Contained)")
104
+ gr.Markdown("This version runs a free, open-source model directly in this Space. No API keys or monthly limits!")
105
+
 
106
  with gr.Group():
107
  with gr.Row(equal_height=False):
108
  with gr.Column(scale=1):
109
+ prompt_input = gr.Textbox(lines=10, placeholder="e.g., Explain the process of making a cup of tea", label="Enter your process description here")
110
+ generate_btn = gr.Button("✨ Generate", variant="primary")
 
 
 
 
 
 
111
  with gr.Column(scale=1):
112
+ output_image = gr.Image(label="Generated Flowchart", type="filepath", interactive=False, value=create_placeholder_image(), height=600, show_label=False)
113
+ download_btn = gr.DownloadButton("⬇️ Download", variant="primary", visible=False)
 
 
 
 
 
114
 
115
def on_generate_click(prompt):
    """Generator callback for the Generate button.

    First yield locks the UI and shows a busy placeholder; second yield
    publishes the rendered flowchart and re-enables the button.
    """
    # Lock the Generate button, hide the download button, and swap in a
    # "thinking" placeholder image while the model runs.
    busy_image = create_placeholder_image("🧠 Thinking... Please wait.\n(First generation can be slow)")
    yield gr.update(interactive=False), gr.update(visible=False), busy_image
    # Run the locally loaded model; no API token is required.
    result_path, download_update = generate_flowchart(prompt)
    # Unlock the button and push the result (or error placeholder) to the UI.
    yield gr.update(interactive=True), download_update, result_path
122
 
123
  generate_btn.click(
124
  fn=on_generate_click,
125
  inputs=[prompt_input],
126
+ outputs=[generate_btn, download_btn, output_image]
127
  )
128
 
129
  if __name__ == "__main__":