Bahar110 commited on
Commit
a002cf8
·
verified ·
1 Parent(s): cca29db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -254
app.py CHANGED
@@ -1,286 +1,67 @@
1
- import ast
2
- import logging
3
  import os
4
  import re
5
- from datetime import datetime
6
-
7
  import gradio as gr
8
  from groq import Groq
9
 
10
- # ---------------------------------------------------------------------------
11
- # Config
12
- # ---------------------------------------------------------------------------
13
  GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
14
- MODEL = "llama-3.3-70b-versatile"
15
- TEMPERATURE = 0.1
16
- MAX_GOAL_CHARS = 4000
17
-
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
-
21
- # Netmiko device_type -> NAPALM driver (when user picks NAPALM)
22
- NAPALM_DRIVER_MAP = {
23
- "cisco_ios": "ios",
24
- "cisco_nxos": "nxos",
25
- "arista_eos": "eos",
26
- "juniper_junos": "junos",
27
- }
28
-
29
- SYSTEM_PROMPT = """You are a Senior Network Automation Architect.
30
- Your task is to generate production-ready Python automation scripts.
31
-
32
- Rules:
33
- - Output ONLY raw Python code. No markdown fences, no ```python, no prose before or after.
34
- - Use Python 3.10+ syntax.
35
- - Never hardcode real passwords, API keys, or production IPs. Use placeholders and environment variables.
36
- - Include try/except for connection failures, sensible timeouts, and basic logging where appropriate.
37
- - Add brief inline comments only for non-obvious steps.
38
- - Include if __name__ == "__main__": with argparse or clear placeholders for hosts, credentials, and paths.
39
- - For Netmiko: use ConnectHandler with device_type exactly as provided in the user message.
40
- - For NAPALM: use the napalm driver name provided in the user message.
41
- - For Nornir: show inventory YAML structure or a minimal InitNornir example with placeholders.
42
- - Scripts must be reviewable in a lab before production use."""
43
-
44
- CUSTOM_CSS = """
45
- .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
46
- .header-box {
47
- text-align: center;
48
- padding: 20px;
49
- background: linear-gradient(90deg, #1e3a8a, #3b82f6);
50
- color: white;
51
- border-radius: 10px;
52
- margin-bottom: 16px;
53
- }
54
- .header-box h1 { margin: 0; font-size: 28px; font-weight: bold; color: white !important; }
55
- .header-box p { margin: 5px 0 0 0; opacity: 0.9; }
56
- .disclaimer {
57
- font-size: 0.85rem;
58
- color: #b45309;
59
- background: #fffbeb;
60
- border: 1px solid #fcd34d;
61
- border-radius: 8px;
62
- padding: 10px 14px;
63
- margin-bottom: 12px;
64
- }
65
- """
66
-
67
-
68
- def strip_code_fences(text: str) -> str:
69
- """Remove ```python ... ``` wrappers if the model ignores instructions."""
70
- if not text:
71
- return ""
72
- text = text.strip()
73
- pattern = r"^```(?:python)?\s*\n?(.*?)\n?```\s*$"
74
- match = re.match(pattern, text, re.DOTALL | re.IGNORECASE)
75
- return match.group(1).strip() if match else text
76
-
77
-
78
- def validate_syntax(code: str) -> str:
79
- """Return status line for UI (does not block output)."""
80
- try:
81
- ast.parse(code)
82
- return "Syntax check: OK"
83
- except SyntaxError as e:
84
- return f"Syntax check: possible issue — {e.msg} (line {e.lineno})"
85
-
86
-
87
- def build_user_prompt(user_goal: str, library: str, device_type: str) -> str:
88
- napalm_driver = NAPALM_DRIVER_MAP.get(device_type, device_type)
89
- return f"""Library: {library}
90
- Netmiko device_type (if using Netmiko): {device_type}
91
- NAPALM driver (if using NAPALM): {napalm_driver}
92
-
93
- Task:
94
- {user_goal.strip()}
95
-
96
- Requirements:
97
- - Match idioms and APIs for {library}.
98
- - Use placeholders for inventory (hostnames, usernames, paths).
99
- - If the task involves many devices, use concurrent execution safely with clear limits.
100
- - End with runnable structure and documented placeholders."""
101
-
102
 
103
- def generate_network_script(
104
- user_goal: str,
105
- library: str,
106
- device_type: str,
107
- progress: gr.Progress = gr.Progress(),
108
- ) -> tuple[str, str, str]:
109
- """
110
- Returns: (code, status_message, download_filename_suggestion)
111
- """
112
- progress(0, desc="Validating input…")
113
 
 
114
  if not GROQ_API_KEY:
115
- err = (
116
- "# Error\n"
117
- "GROQ_API_KEY is missing. Add it in Space Settings → "
118
- "Variables and secrets."
119
- )
120
- return err, "Missing API key", ""
121
-
122
  goal = (user_goal or "").strip()
123
  if not goal:
124
- return "# Error\nPlease describe the task.", "Empty task", ""
125
- if len(goal) > MAX_GOAL_CHARS:
126
- return (
127
- f"# Error\nTask description exceeds {MAX_GOAL_CHARS} characters.",
128
- "Input too long",
129
- "",
130
- )
131
 
132
- progress(0.2, desc="Calling Groq…")
133
- client = Groq(api_key=GROQ_API_KEY)
134
- user_prompt = build_user_prompt(goal, library, device_type)
135
 
136
  try:
137
- chat_completion = client.chat.completions.create(
 
 
138
  messages=[
139
- {"role": "system", "content": SYSTEM_PROMPT},
140
  {"role": "user", "content": user_prompt},
141
  ],
142
- model=MODEL,
143
- temperature=TEMPERATURE,
144
- )
145
- raw = chat_completion.choices[0].message.content or ""
146
- code = strip_code_fences(raw)
147
- progress(0.9, desc="Checking syntax…")
148
- status = validate_syntax(code)
149
- progress(1.0, desc="Done")
150
- filename = (
151
- f"autonet_{library.lower()}_{device_type}_"
152
- f"{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.py"
153
- )
154
- return code, status, filename
155
-
156
- except Exception:
157
- logger.exception("Groq generation failed")
158
- return (
159
- "# Error\n"
160
- "Generation failed. Check Space logs or try again in a moment.",
161
- "Generation failed",
162
- "",
163
  )
164
-
165
-
166
- def on_generate(user_goal, library, device_type, progress=gr.Progress()):
167
- code, status, filename = generate_network_script(
168
- user_goal, library, device_type, progress
169
- )
170
- return code, status, code, filename
171
-
172
-
173
- # Example prompts for first-time users
174
- EXAMPLES = [
175
- [
176
- "Log into a list of Cisco switches via SSH, backup running-config, "
177
- "and save each file as {hostname}_{date}_running.cfg in ./backups/",
178
- "Netmiko",
179
- "cisco_ios",
180
- ],
181
- [
182
- "Connect to three Arista switches, collect interface status and "
183
- "print any interface that is admin down but operationally up.",
184
- "NAPALM",
185
- "arista_eos",
186
- ],
187
- [
188
- "Using an inventory file, run 'show version' on all Juniper devices "
189
- "and save output to logs/{host}.txt",
190
- "Nornir",
191
- "juniper_junos",
192
- ],
193
- ]
194
-
195
- HEADER_HTML = """
196
- <div class="header-box">
197
- <h1>⚙️ AutoNet Architect</h1>
198
- <p>AI-Powered Python Script Generator for Network Engineers</p>
199
- </div>
200
- """
201
-
202
- DISCLAIMER_HTML = """
203
- <div class="disclaimer">
204
- <strong>Review before you run.</strong> Generated code may be wrong or unsafe.
205
- Test in a lab first. Never paste production passwords into this form.
206
- You are responsible for anything executed on your network.
207
- </div>
208
  """
209
 
210
-
211
- with gr.Blocks(
212
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo"),
213
- css=CUSTOM_CSS,
214
- title="AutoNet Architect",
215
- ) as app:
216
-
217
- gr.HTML(HEADER_HTML)
218
- gr.HTML(DISCLAIMER_HTML)
219
-
220
  with gr.Row():
221
  with gr.Column(scale=1):
222
- library = gr.Dropdown(
223
- choices=["Netmiko", "NAPALM", "Nornir"],
224
- value="Netmiko",
225
- label="1. Select Python Library",
226
- info="Which automation framework do you want to use?",
227
- )
228
  device_type = gr.Dropdown(
229
- choices=["cisco_ios", "cisco_nxos", "arista_eos", "juniper_junos"],
230
  value="cisco_ios",
231
- label="2. Select Device OS",
232
- info="Netmiko device_type; NAPALM driver is mapped automatically.",
233
- )
234
- user_goal = gr.Textbox(
235
- lines=6,
236
- label="3. Describe the Task",
237
- placeholder=(
238
- "e.g., Log into 50 switches, backup running-config, and save "
239
- "them as text files with hostname and date."
240
- ),
241
  )
242
-
243
- gr.Examples(
244
- examples=EXAMPLES,
245
- inputs=[user_goal, library, device_type],
246
- label="Example tasks",
247
- )
248
-
249
- generate_btn = gr.Button(
250
- "✨ Generate Code",
251
- variant="primary",
252
- size="lg",
253
- )
254
-
255
  with gr.Column(scale=2):
256
- status_md = gr.Markdown("*Ready.*")
257
- output_code = gr.Code(
258
- label="Production-Ready Python Script",
259
- language="python",
260
- lines=28,
261
- interactive=True,
262
- )
263
- with gr.Row():
264
- download_btn = gr.DownloadButton(
265
- "⬇️ Download .py",
266
- variant="secondary",
267
- )
268
-
269
- # Hidden state for download filename content (same as code)
270
- download_file = gr.File(visible=False)
271
 
272
  generate_btn.click(
273
- fn=on_generate,
274
  inputs=[user_goal, library, device_type],
275
- outputs=[output_code, status_md, download_btn, download_btn],
276
- show_progress="full",
277
  )
278
 
279
-
280
- # Hugging Face Spaces
281
  if __name__ == "__main__":
282
- app.queue(default_concurrency_limit=2)
283
- app.launch(
284
- server_name="0.0.0.0",
285
- server_port=int(os.environ.get("PORT", 7860)),
286
- )
 
 
 
1
  import os
2
  import re
 
 
3
  import gradio as gr
4
  from groq import Groq
5
 
 
 
 
6
  GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ def strip_code_fences(text):
9
+ text = (text or "").strip()
10
+ m = re.match(r"^```(?:python)?\s*\n?(.*?)\n?```\s*$", text, re.DOTALL | re.IGNORECASE)
11
+ return m.group(1).strip() if m else text
 
 
 
 
 
 
12
 
13
+ def generate_network_script(user_goal, library, device_type):
14
  if not GROQ_API_KEY:
15
+ return "# Error\nGROQ_API_KEY is missing. Add it in Space Settings → Variables and secrets."
 
 
 
 
 
 
16
  goal = (user_goal or "").strip()
17
  if not goal:
18
+ return "# Error\nPlease describe the task."
 
 
 
 
 
 
19
 
20
+ system_prompt = """You are a Senior Network Automation Architect.
21
+ Output ONLY raw Python code. No markdown fences. No extra text."""
22
+ user_prompt = f"Library: {library}\nDevice: {device_type}\nTask: {goal}"
23
 
24
  try:
25
+ client = Groq(api_key=GROQ_API_KEY)
26
+ r = client.chat.completions.create(
27
+ model="llama-3.3-70b-versatile",
28
  messages=[
29
+ {"role": "system", "content": system_prompt},
30
  {"role": "user", "content": user_prompt},
31
  ],
32
+ temperature=0.1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
34
+ return strip_code_fences(r.choices[0].message.content)
35
+ except Exception as e:
36
+ return f"# Error\n{str(e)}"
37
+
38
+ custom_css = """
39
+ .gradio-container { font-family: 'Inter', sans-serif; }
40
+ .header-box { text-align: center; padding: 20px; background: linear-gradient(90deg, #1e3a8a, #3b82f6);
41
+ color: white; border-radius: 10px; margin-bottom: 20px; }
42
+ .header-box h1 { margin: 0; font-size: 28px; color: white !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
 
45
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo"), css=custom_css) as app:
46
+ gr.HTML('<div class="header-box"><h1>⚙️ AutoNet Architect</h1><p>AI script generator</p></div>')
 
 
 
 
 
 
 
 
47
  with gr.Row():
48
  with gr.Column(scale=1):
49
+ library = gr.Dropdown(["Netmiko", "NAPALM", "Nornir"], value="Netmiko", label="Library")
 
 
 
 
 
50
  device_type = gr.Dropdown(
51
+ ["cisco_ios", "cisco_nxos", "arista_eos", "juniper_junos"],
52
  value="cisco_ios",
53
+ label="Device OS",
 
 
 
 
 
 
 
 
 
54
  )
55
+ user_goal = gr.Textbox(lines=5, label="Describe the task")
56
+ generate_btn = gr.Button("✨ Generate Code", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
57
  with gr.Column(scale=2):
58
+ output_code = gr.Code(label="Python Script", language="python", interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  generate_btn.click(
61
+ fn=generate_network_script,
62
  inputs=[user_goal, library, device_type],
63
+ outputs=output_code,
 
64
  )
65
 
 
 
66
  if __name__ == "__main__":
67
+ app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))