paulcalzada commited on
Commit
1857a25
·
1 Parent(s): 05e3f60

changed app to integrate verilogagent

Browse files
Files changed (1) hide show
  1. app.py +308 -4
app.py CHANGED
@@ -1,7 +1,311 @@
 
 
 
1
  import gradio as gr
2
 
3
def greet(name):
    """Return the demo greeting for *name* (a string from the text input)."""
    return f"Hello {name}!!"
 
5
 
6
# Minimal Gradio wiring: a single text box in, a single text box out.
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
  import gradio as gr
5
 
6
+ from huggingface_hub import snapshot_download
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
 
10
# ---- Your class from the prompt (unchanged except 1 tweak: import path safety) ----
import re
from pathlib import Path as _Path

# The OpenAI client is optional at import time: if the package is missing we
# record None and raise an actionable error later, instead of crashing the
# whole Space on startup.
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None
17
+
18
class VerilogAgent:
    """
    A self-contained agent for generating Verilog code using API-based LLMs (e.g., GPT)
    and a RAG pipeline.

    The agent loads a HuggingFace embedding model plus a pre-built FAISS index,
    builds chat prompts (with or without retrieved context), calls the OpenAI
    chat-completions API, and extracts the Verilog code from the reply.
    """

    def __init__(self, model_id, embedding_id, faiss_index_path, api_key):
        """
        Args:
            model_id: OpenAI chat model name (e.g. "gpt-4o").
            embedding_id: HuggingFace embedding model id used by the retriever.
            faiss_index_path: Directory containing the saved FAISS index.
            api_key: OpenAI API key (required).
        """
        self.model_id = model_id
        self.embedding_id = embedding_id
        self.faiss_index_path = faiss_index_path
        self.api_key = api_key

        print(f"[INFO] Initializing VerilogAgent for model: {self.model_id}")
        self._load_dependencies()

    def _load_dependencies(self):
        """Load the embedding model, the FAISS store, and the OpenAI client.

        Raises:
            FileNotFoundError: if the FAISS index directory does not exist.
            ImportError: if the `openai` package is not installed.
            ValueError: if no API key was provided.
        """
        print(f"[INFO] Loading embedding model '{self.embedding_id}'...")
        embedding = HuggingFaceEmbeddings(model_name=self.embedding_id)

        print(f"[INFO] Loading FAISS vector store from '{self.faiss_index_path}'...")
        if not _Path(self.faiss_index_path).exists():
            raise FileNotFoundError(f"FAISS index directory not found at {self.faiss_index_path}.")
        # allow_dangerous_deserialization: the index is our own trusted artifact,
        # downloaded from the owner's private dataset — not untrusted user input.
        self.vectorstore = FAISS.load_local(self.faiss_index_path, embedding, allow_dangerous_deserialization=True)

        if not OpenAI:
            raise ImportError("OpenAI library is not installed. Please add 'openai' to requirements.txt.")
        if not self.api_key:
            raise ValueError("OpenAI API key is required.")
        self.client = OpenAI(api_key=self.api_key)
        print("[INFO] OpenAI client initialized.")
        print("[INFO] VerilogAgent initialized successfully.")

    def _build_prompt_messages_rag(self, query: str, docs: list = None) -> list:
        """Build the system+user chat messages for RAG generation.

        Args:
            query: The user's design specification.
            docs: Retrieved LangChain Documents whose page_content is inlined
                as context examples; empty/None means no context section.

        Returns:
            A two-element list of chat messages (system, user).
        """
        context_section = ""
        if docs:
            context = "\n\n".join([doc.page_content for doc in docs])
            context_section = f"""
CONTEXT EXAMPLES:
```verilog
{context}
```"""

        system_prompt = """You are an expert Verilog code generation assistant.
TASK: Generate fully implemented, syntactically correct Verilog code in response to the user request.

INSTRUCTIONS:
1. Analyze the user request and the provided context examples to determine the required modules and logic.
2. The context provides examples of valid, complete Verilog modules.
3. Implement all required modules. Every `module ... endmodule` block must be complete.
4. Do not leave logic empty or use placeholders like `// your code here`.
5. Your entire response MUST be only the Verilog code, wrapped in a single `verilog` markdown block. Do not include any natural language explanations.
6. Use only Verilog-2005 syntax. Do not use SystemVerilog constructs (e.g., `logic`, `always_ff`).
7. Ensure all identifiers are declared before use and vector ranges are ordered [MSB:LSB].
"""
        user_prompt = f'''TASK: Generate a fully implemented, syntactically correct Verilog module named 'TopModule'. This name is a strict requirement.
{context_section}

USER REQUEST:
"""{query}"""

OUTPUT:
Generate the complete Verilog code.
'''
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

    def _build_prompt_messages_baseline(self, query: str, docs: list = None) -> list:
        """Build chat messages for the no-RAG baseline (minimal system prompt).

        Mirrors _build_prompt_messages_rag but with a single-sentence system
        prompt, so RAG-vs-baseline comparisons differ only in instructions/context.
        """
        context_section = ""
        if docs:
            context = "\n\n".join([doc.page_content for doc in docs])
            context_section = f"""
CONTEXT EXAMPLES:
```verilog
{context}
```"""
        system_prompt = "You are an expert Verilog code generation assistant."
        user_prompt = f'''TASK: Generate a fully implemented, syntactically correct Verilog module named 'TopModule'. This name is a strict requirement.
{context_section}

USER REQUEST:
"""{query}"""

OUTPUT:
Generate the complete Verilog code.
'''
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

    def _extract_verilog_code(self, text: str) -> str:
        """Extract Verilog source from an LLM reply.

        Tries, in order: a fenced markdown block (``` or ```verilog), then a
        bare `module ... endmodule` span, then falls back to the stripped text.
        """
        verilog_pattern = re.compile(r"```(?:verilog\s*)?(.*?)\s*```", re.DOTALL)
        match = verilog_pattern.search(text)
        if match:
            return match.group(1).strip()
        module_pattern = re.compile(r"(module.*?endmodule)", re.DOTALL)
        match = module_pattern.search(text)
        if match:
            return match.group(1).strip()
        return text.strip()

    def _call_api(self, messages: list, generation_params: dict) -> str:
        """Call the chat-completions API and return the extracted Verilog code.

        Failures are returned as a Verilog comment string rather than raised,
        so the UI always receives displayable text.
        """
        try:
            api_params = {
                "model": self.model_id,
                "messages": messages,
                "max_tokens": generation_params.get("max_new_tokens"),
                "temperature": generation_params.get("temperature"),
                "top_p": generation_params.get("top_p")
            }
            # FIX: drop parameters the caller did not set instead of sending
            # explicit nulls, which the API may reject.
            api_params = {k: v for k, v in api_params.items() if v is not None}
            if "gpt-5" in self.model_id and "verbosity" in generation_params:
                api_params["verbosity"] = generation_params["verbosity"]

            completion = self.client.chat.completions.create(**api_params)
            return self._extract_verilog_code(completion.choices[0].message.content)
        except Exception as e:
            print(f"[ERROR] Code generation failed for model {self.model_id}: {e}")
            return f"// ERROR: Generation failed. Details: {e}"

    def generate_with_context(self, spec: str, docs_with_scores: list, generation_params: dict) -> str:
        """Generate Verilog for *spec* using retrieved (Document, score) pairs as context."""
        relevant_docs = [doc for doc, _score in docs_with_scores]
        messages = self._build_prompt_messages_rag(spec, relevant_docs)
        return self._call_api(messages, generation_params)

    def generate_baseline(self, spec: str, generation_params: dict) -> str:
        """Generate Verilog for *spec* without any retrieved context."""
        messages = self._build_prompt_messages_baseline(spec, docs=[])
        return self._call_api(messages, generation_params)
146
+
147
+
148
# --------------------------- Space wiring below ---------------------------

# The FAISS index is downloaded from a private dataset into the Space's
# ephemeral disk, so large binaries never live in the public repo.
CACHE_DIR = Path("/data/faiss_index")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# All configuration arrives via Space secrets ("Settings → Repository secrets").
HF_TOKEN = os.getenv("HF_TOKEN")                          # read-scoped personal access token
PRIVATE_DATASET_ID = os.getenv("PRIVATE_DATASET_ID")      # e.g. "yourname/verilog-faiss-index"
INDEX_SUBDIR = os.getenv("INDEX_SUBDIR", "faiss_index")   # optional subdir inside the snapshot
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
159
+
160
def ensure_index_downloaded() -> Path:
    """Fetch the private FAISS-index dataset once per container and return its path.

    The download is skipped when a non-empty cached copy already exists, so
    repeated calls within one container are cheap. Large binaries stay out of
    the public Space repo this way.

    Raises:
        RuntimeError: if the HF_TOKEN or PRIVATE_DATASET_ID secret is missing.
    """
    target = CACHE_DIR / INDEX_SUBDIR
    if target.exists() and any(target.iterdir()):
        print(f"[INFO] Using cached FAISS index at {target}")
        return target

    # Both secrets must be configured before we can talk to the Hub.
    if not HF_TOKEN:
        raise RuntimeError("Missing HF_TOKEN secret. Add it in the Space settings.")
    if not PRIVATE_DATASET_ID:
        raise RuntimeError("Missing PRIVATE_DATASET_ID secret (e.g., 'user/private-faiss').")

    print(f"[INFO] Downloading private dataset: {PRIVATE_DATASET_ID}")
    snapshot_root = Path(
        snapshot_download(
            repo_id=PRIVATE_DATASET_ID,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(CACHE_DIR),
            local_dir_use_symlinks=False,  # safer for FAISS; NOTE(review): deprecated in newer huggingface_hub — confirm version
        )
    )

    # Prefer the configured subdirectory when the dataset nests its index files.
    subdir = snapshot_root / INDEX_SUBDIR
    if subdir.exists():
        print(f"[INFO] Found index subdir at {subdir}")
        return subdir

    # Otherwise assume the snapshot root itself holds the index.
    print(f"[WARN] INDEX_SUBDIR='{INDEX_SUBDIR}' not found; using snapshot root.")
    return snapshot_root
192
+
193
# Module-level cache so the index path is resolved (and possibly downloaded)
# only once per process, not on every button click.
_VECTORSTORE_PATH = None

def get_vectorstore_path() -> Path:
    """Return the FAISS index directory, downloading it on first use."""
    global _VECTORSTORE_PATH
    if _VECTORSTORE_PATH is not None:
        return _VECTORSTORE_PATH
    _VECTORSTORE_PATH = ensure_index_downloaded()
    return _VECTORSTORE_PATH
201
+
202
def run_generation(spec, use_rag, top_k, model_choice, api_key, temperature, top_p, max_new_tokens):
    """Gradio callback: build the agent, optionally retrieve context, generate Verilog.

    Returns a 3-tuple matching the click() outputs:
        (verilog code string, retriever summary text, HighlightedText data).
    Errors are reported as Verilog comments in the first slot rather than
    raised, so the UI always renders something.
    """
    if not spec or not api_key:
        return "// Please provide a design specification and your API key.", "", []

    # Prepare agent. FIX: the index download can also fail (missing secrets,
    # network), so it now lives inside the try-block instead of crashing the
    # callback with an unhandled exception.
    try:
        faiss_path = get_vectorstore_path()
        agent = VerilogAgent(
            model_id=model_choice,
            embedding_id=EMBEDDING_MODEL,
            faiss_index_path=str(faiss_path),
            api_key=api_key.strip()
        )
    except Exception as e:
        return f"// Initialization error: {e}", "", []

    # Retrieval (if enabled)
    docs_with_scores = []
    retrieved_preview = []
    if use_rag:
        try:
            # similarity_search_with_score returns list[(Document, score)];
            # top_k arrives from a gr.Slider, so coerce to int defensively.
            docs_with_scores = agent.vectorstore.similarity_search_with_score(spec, k=int(top_k))
            for doc, score in docs_with_scores:
                src = doc.metadata.get("source_file", doc.metadata.get("module", "unknown"))
                retrieved_preview.append(f"{src} | score={score:.4f}")
        except Exception as e:
            return f"// Retrieval error: {e}", "", []

    # Generation parameters forwarded to the OpenAI API.
    gen_params = {
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": int(max_new_tokens),
    }

    if use_rag:
        code = agent.generate_with_context(spec, docs_with_scores, gen_params)
    else:
        code = agent.generate_baseline(spec, gen_params)

    # Clean presentation
    verilog_block = code.strip()
    summary = "\n".join(retrieved_preview) if retrieved_preview else ""
    # FIX: gr.HighlightedText expects (text, label) pairs, not bare strings;
    # a None label renders the text without a highlight class.
    raw_preview = [(doc.page_content, None) for doc, _score in docs_with_scores]
    return verilog_block, summary, raw_preview
247
+
248
+
249
# UI: spec + controls on the left, generated code and retrieval transparency
# on the right. Output components must stay in sync with run_generation's
# 3-tuple return value.
with gr.Blocks(title="DeepRAG for RTL (Model-Agnostic)") as demo:
    gr.Markdown("## DeepRAG for RTL Code Generation — Model-Agnostic (Bring Your Own API Key)")

    with gr.Row():
        with gr.Column(scale=2):
            spec = gr.Textbox(
                label="Design Specification (natural language or I/O contract)",
                placeholder="e.g., 8-bit UART transmitter with baud rate generator ...",
                lines=10
            )
            with gr.Row():
                use_rag = gr.Checkbox(value=True, label="Use Retrieval (RAG)")
                top_k = gr.Slider(1, 10, value=3, step=1, label="Top-K retrieved examples")

            with gr.Row():
                model_choice = gr.Dropdown(
                    choices=[
                        "gpt-4o",
                        "gpt-4o-mini",
                        "gpt-4.1",
                        "gpt-5",      # hypothetical/future-ready
                        "gpt-5-mini"  # hypothetical/future-ready
                    ],
                    value="gpt-4o",
                    label="Model"
                )
                api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")

            with gr.Accordion("Generation Settings", open=False):
                temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                max_new_tokens = gr.Slider(128, 4096, value=768, step=64, label="Max tokens")

            run_btn = gr.Button("Generate Verilog", variant="primary")

        # Right side: code output + retrieval transparency
        with gr.Column(scale=3):
            gr.Markdown("**Output**")
            out_code = gr.Code(
                label="Generated Verilog (copy-ready)",
                # FIX: "verilog" is not among gr.Code's supported languages and
                # raises ValueError at startup; None renders plain monospace text.
                language=None,
                interactive=False,
                show_copy_button=True,
                lines=28
            )
            with gr.Tab("Retrieved Items (names + scores)"):
                retrieved_list = gr.Textbox(
                    label="Retriever summary",
                    lines=8,
                    interactive=False
                )
            with gr.Tab("Preview of Retrieved Context (raw)"):
                # shows the raw text of retrieved docs for transparency (not downloadable)
                retrieved_raw = gr.HighlightedText(label="(first K documents)", combine_adjacent=True)

    run_btn.click(
        fn=run_generation,
        inputs=[spec, use_rag, top_k, model_choice, api_key, temperature, top_p, max_new_tokens],
        outputs=[out_code, retrieved_list, retrieved_raw]
    )

if __name__ == "__main__":
    demo.launch()