paulcalzada committed on
Commit
e1cd8a1
·
1 Parent(s): c14cd0b

offloaded main agent to private repo

Browse files
Files changed (1) hide show
  1. app.py +41 -233
app.py CHANGED
@@ -1,196 +1,38 @@
1
  import os
2
- import shutil
3
  from pathlib import Path
4
  import gradio as gr
 
 
5
 
6
- from huggingface_hub import snapshot_download
7
- from langchain_community.vectorstores import FAISS
8
- from langchain_huggingface import HuggingFaceEmbeddings
 
9
 
10
# ---- Agent implementation: API-based LLM (e.g., GPT) + RAG pipeline ----
import re
from pathlib import Path as _Path

try:
    from openai import OpenAI
except ImportError:
    # Defer the hard failure to agent construction so the module still imports.
    OpenAI = None


class VerilogAgent:
    """
    A self-contained agent for generating Verilog code using API-based LLMs
    (e.g., GPT) and a RAG pipeline.
    """

    # Compiled once: extraction patterns used on every model response.
    _FENCE_RE = re.compile(r"```(?:verilog\s*)?(.*?)\s*```", re.DOTALL)
    _MODULE_RE = re.compile(r"(module.*?endmodule)", re.DOTALL)

    # Shared task preamble — both the RAG and baseline prompts require the
    # generated top-level module to be named 'TopModule'.
    _USER_PROMPT_TEMPLATE = '''TASK: Generate a fully implemented, syntactically correct Verilog module named 'TopModule'. This name is a strict requirement.
{context_section}

USER REQUEST:
"""{query}"""

OUTPUT:
Generate the complete Verilog code.
'''

    _RAG_SYSTEM_PROMPT = """You are an expert Verilog code generation assistant.
TASK: Generate fully implemented, syntactically correct Verilog code in response to the user request.

INSTRUCTIONS:
1. Analyze the user request and the provided context examples to determine the required modules and logic.
2. The context provides examples of valid, complete Verilog modules.
3. Implement all required modules. Every `module ... endmodule` block must be complete.
4. Do not leave logic empty or use placeholders like `// your code here`.
5. Your entire response MUST be only the Verilog code, wrapped in a single `verilog` markdown block. Do not include any natural language explanations.
6. Use only Verilog-2005 syntax. Do not use SystemVerilog constructs (e.g., `logic`, `always_ff`).
7. Ensure all identifiers are declared before use and vector ranges are ordered [MSB:LSB].
"""

    _BASELINE_SYSTEM_PROMPT = "You are an expert Verilog code generation assistant."

    def __init__(self, model_id, embedding_id, faiss_index_path, api_key):
        """
        Args:
            model_id: OpenAI chat model name (e.g. "gpt-4o").
            embedding_id: HuggingFace embedding model name for the FAISS store.
            faiss_index_path: directory containing index.faiss / index.pkl.
            api_key: OpenAI API key (required).
        """
        self.model_id = model_id
        self.embedding_id = embedding_id
        self.faiss_index_path = faiss_index_path
        self.api_key = api_key

        print(f"[INFO] Initializing VerilogAgent for model: {self.model_id}")
        self._load_dependencies()

    def _load_dependencies(self):
        """Load the embedding model, the FAISS vector store, and the OpenAI client.

        Raises:
            FileNotFoundError: if the FAISS index directory is missing.
            ImportError: if the openai package is not installed.
            ValueError: if no API key was supplied.
        """
        print(f"[INFO] Loading embedding model '{self.embedding_id}'...")
        embedding = HuggingFaceEmbeddings(model_name=self.embedding_id)

        print(f"[INFO] Loading FAISS vector store from '{self.faiss_index_path}'...")
        if not _Path(self.faiss_index_path).exists():
            raise FileNotFoundError(f"FAISS index directory not found at {self.faiss_index_path}.")
        # allow_dangerous_deserialization: the index is pickled — only load trusted data.
        self.vectorstore = FAISS.load_local(self.faiss_index_path, embedding, allow_dangerous_deserialization=True)

        if not OpenAI:
            raise ImportError("OpenAI library is not installed. Please add 'openai' to requirements.txt.")
        if not self.api_key:
            raise ValueError("OpenAI API key is required.")
        self.client = OpenAI(api_key=self.api_key)
        print("[INFO] OpenAI client initialized.")
        print("[INFO] VerilogAgent initialized successfully.")

    @staticmethod
    def _format_context_section(docs) -> str:
        """Render retrieved documents as a fenced Verilog context block ('' if no docs)."""
        if not docs:
            return ""
        context = "\n\n".join(doc.page_content for doc in docs)
        return f"""
CONTEXT EXAMPLES:
```verilog
{context}
```"""

    def _build_prompt_messages(self, query: str, docs, system_prompt: str) -> list:
        """Shared message builder (deduplicates the former RAG/baseline copy-paste)."""
        user_prompt = self._USER_PROMPT_TEMPLATE.format(
            context_section=self._format_context_section(docs),
            query=query,
        )
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def _build_prompt_messages_rag(self, query: str, docs: list = None) -> list:
        """Messages for RAG generation: detailed instructions + retrieved context."""
        return self._build_prompt_messages(query, docs, self._RAG_SYSTEM_PROMPT)

    def _build_prompt_messages_baseline(self, query: str, docs: list = None) -> list:
        """Messages for the no-RAG baseline: minimal system prompt, same task text."""
        return self._build_prompt_messages(query, docs, self._BASELINE_SYSTEM_PROMPT)

    def _extract_verilog_code(self, text: str) -> str:
        """Pull Verilog out of a model response.

        Preference order: fenced ``` block, then a bare module...endmodule
        span, else the stripped raw text.
        """
        match = self._FENCE_RE.search(text)
        if match:
            return match.group(1).strip()
        match = self._MODULE_RE.search(text)
        if match:
            return match.group(1).strip()
        return text.strip()

    def _call_api(self, messages: list, generation_params: dict) -> str:
        """Send one chat-completion request and return the extracted Verilog.

        Never raises: on any failure it returns a `// ERROR:` Verilog comment
        so the UI always receives renderable text.
        """
        try:
            api_params = {
                "model": self.model_id,
                "messages": messages,
                "max_tokens": generation_params.get("max_new_tokens"),
                "temperature": generation_params.get("temperature"),
                "top_p": generation_params.get("top_p"),
            }
            # NOTE(review): 'verbosity' as a top-level create() kwarg may not be
            # accepted by every SDK version — confirm against the OpenAI client docs.
            if "gpt-5" in self.model_id and "verbosity" in generation_params:
                api_params["verbosity"] = generation_params["verbosity"]

            completion = self.client.chat.completions.create(**api_params)
            return self._extract_verilog_code(completion.choices[0].message.content)
        except Exception as e:
            print(f"[ERROR] Code generation failed for model {self.model_id}: {e}")
            return f"// ERROR: Generation failed. Details: {e}"

    def generate_with_context(self, spec: str, docs_with_scores: list, generation_params: dict) -> str:
        """Generate with RAG context; docs_with_scores is a list of (doc, score) pairs."""
        relevant_docs = [doc for doc, _score in docs_with_scores]
        messages = self._build_prompt_messages_rag(spec, relevant_docs)
        return self._call_api(messages, generation_params)

    def generate_baseline(self, spec: str, generation_params: dict) -> str:
        """Generate without any retrieval context."""
        messages = self._build_prompt_messages_baseline(spec, docs=[])
        return self._call_api(messages, generation_params)
146
-
147
-
148
# --------------------------- Space wiring below ---------------------------

CACHE_DIR = Path("./faiss_index")  # index cache in the Space's ephemeral filesystem
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Env vars set in the Space "Settings -> Repository secrets"
HF_TOKEN = os.getenv("HF_TOKEN")                      # personal access token with read permission
PRIVATE_DATASET_ID = os.getenv("PRIVATE_DATASET_ID")  # e.g. "yourname/VerilogDB_faiss"
INDEX_SUBDIR = os.getenv("INDEX_SUBDIR", ".")         # "." because the files sit at the repo root
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")


def ensure_index_downloaded() -> Path:
    """Download the private dataset (FAISS index + artifacts) once per container.

    Avoids committing large binaries to the public Space repo.

    Returns:
        Path to the directory containing ``index.faiss`` and ``index.pkl``.

    Raises:
        RuntimeError: if HF_TOKEN or PRIVATE_DATASET_ID is not configured.
        FileNotFoundError: if the downloaded dataset contains no index files.
    """
    target = CACHE_DIR / INDEX_SUBDIR
    # BUG FIX: use is_dir() — the old `target.exists() and target.iterdir()`
    # raised NotADirectoryError when a stray file occupied `target`.
    if target.is_dir() and any(target.iterdir()):
        print(f"[INFO] Using cached FAISS index at {target}")
        return target

    if not HF_TOKEN:
        raise RuntimeError("Missing HF_TOKEN secret. Add it in the Space settings.")
    if not PRIVATE_DATASET_ID:
        raise RuntimeError("Missing PRIVATE_DATASET_ID secret (e.g., 'user/VerilogDB_faiss').")

    print(f"[INFO] Downloading private dataset: {PRIVATE_DATASET_ID}")
    # `local_dir_use_symlinks` is deprecated/ignored in current huggingface_hub;
    # files given a local_dir are materialized there directly.
    snapshot_path = snapshot_download(
        repo_id=PRIVATE_DATASET_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        local_dir=str(CACHE_DIR),
    )

    # Walk the downloaded tree to locate the FAISS index files.
    downloaded_dir = Path(snapshot_path)
    for root, _dirs, files in os.walk(downloaded_dir):
        if "index.faiss" in files and "index.pkl" in files:
            index_path = Path(root)
            print(f"[INFO] Found index files at {index_path}. Using this path.")
            return index_path

    # Nothing matched anywhere in the snapshot — fail loudly.
    raise FileNotFoundError("FAISS index files (index.faiss and index.pkl) not found in the downloaded dataset.")
193
-
194
 
195
  # Keep a lightweight global cache so we don’t reload embeddings on every click
196
  _VECTORSTORE_PATH = None
@@ -201,58 +43,12 @@ def get_vectorstore_path() -> Path:
201
  _VECTORSTORE_PATH = ensure_index_downloaded()
202
  return _VECTORSTORE_PATH
203
 
204
def run_generation(spec, use_rag, top_k, model_choice, api_key, temperature, top_p, max_new_tokens):
    """Drive one generation request from the UI.

    Returns a 3-tuple:
        (verilog code string,
         newline-joined retrieval preview text,
         list of (chunk_text, None) pairs for the HighlightedText widget).
    """
    # Guard clause: nothing to do without a spec and a key.
    if not spec or not api_key:
        return "// Please provide a design specification and your API key.", "", []

    # Build the agent (loads embeddings + FAISS index, creates the OpenAI client).
    try:
        faiss_path = get_vectorstore_path()
        agent = VerilogAgent(
            model_id=model_choice,
            embedding_id=EMBEDDING_MODEL,
            faiss_index_path=str(faiss_path),
            api_key=api_key.strip(),
        )
    except Exception as err:
        return f"// Initialization error: {err}", "", []

    # Retrieval step (only when RAG is enabled).
    docs_with_scores = []
    preview_lines = []
    highlighted_chunks = []  # (page_content, None) tuples for HighlightedText
    if use_rag:
        try:
            docs_with_scores = agent.vectorstore.similarity_search_with_score(spec, k=top_k)
            for doc, score in docs_with_scores:
                src = doc.metadata.get("source_file", doc.metadata.get("module", "unknown"))
                preview_lines.append(f"{src} | score={score:.4f}")
                highlighted_chunks.append((doc.page_content, None))
        except Exception as err:
            return f"// Retrieval error: {err}", "", []

    # Generation step.
    gen_params = {
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": int(max_new_tokens),
    }
    code = (
        agent.generate_with_context(spec, docs_with_scores, gen_params)
        if use_rag
        else agent.generate_baseline(spec, gen_params)
    )

    preview = "\n".join(preview_lines) if preview_lines else ""
    return code.strip(), preview, highlighted_chunks
249
-
250
  with gr.Blocks(title="DeepV for RTL (Model-Agnostic)") as demo:
251
  gr.Markdown("## DeepV for RTL Code Generation — Model-Agnostic (Bring Your Own API Key)")
252
 
253
  with gr.Row():
254
  with gr.Column(scale=2):
255
- # Moved model choice and API key to the top of the left column
256
  with gr.Row():
257
  model_choice = gr.Dropdown(
258
  choices=[
@@ -260,13 +56,18 @@ with gr.Blocks(title="DeepV for RTL (Model-Agnostic)") as demo:
260
  "gpt-4o-mini",
261
  "gpt-4.1",
262
  "gpt-5-chat-latest",
263
- "gpt-5-mini"
264
  ],
265
  value="gpt-4o",
266
  label="Model"
267
  )
268
  api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
269
 
 
 
 
 
 
 
270
  spec = gr.Textbox(
271
  label="Design Specification (natural language or I/O contract)",
272
  placeholder="e.g., 8-bit UART transmitter with baud rate generator ...",
@@ -306,4 +107,11 @@ with gr.Blocks(title="DeepV for RTL (Model-Agnostic)") as demo:
306
  )
307
 
308
  if __name__ == "__main__":
309
- demo.launch()
 
 
 
 
 
 
 
 
1
  import os
 
2
  from pathlib import Path
3
  import gradio as gr
4
+ from huggingface_hub import snapshot_download, hf_hub_download
5
+ import importlib.util
6
 
7
# This is the path to your private dataset repository on Hugging Face Hub,
# plus the access token and (optional) subdirectory for the index.
PRIVATE_DATASET_ID = os.getenv("PRIVATE_DATASET_ID")
HF_TOKEN = os.getenv("HF_TOKEN")
INDEX_SUBDIR = os.getenv("INDEX_SUBDIR", ".")

# --- Core Logic Download and Import ---
try:
    # First, download the core agent code from the private repo.
    AGENT_CODE_PATH = hf_hub_download(
        repo_id=PRIVATE_DATASET_ID,
        filename="deepv_core.py",  # The file containing your VerilogAgent class
        repo_type="dataset",       # Ensure this is 'dataset' to match your repo type
        token=HF_TOKEN,
    )
    # Dynamically load the agent module from the downloaded file.
    # Renamed from `spec`: that name is later reused for a Gradio textbox,
    # and shadowing it here invites confusion.
    module_spec = importlib.util.spec_from_file_location("deepv_core_module", AGENT_CODE_PATH)
    agent_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(agent_module)

    # Re-export the private module's public entry points.
    VerilogAgent = agent_module.VerilogAgent
    run_generation = agent_module.run_generation
    get_vectorstore_path = agent_module.get_vectorstore_path
    ensure_index_downloaded = agent_module.ensure_index_downloaded

except Exception as exc:
    # BUG FIX: the name bound by `except ... as` is deleted when the handler
    # exits (Python 3 semantics), so later references to `e` (e.g. the
    # __main__ error UI) crashed with NameError. Persist it explicitly.
    e = exc

    # Handle the error gracefully if the private repo can't be accessed.
    def show_error(*args):
        """Fallback handler used when the private agent code cannot be loaded."""
        return f"// ERROR: Failed to load core agent code. Check your Hugging Face token and private dataset configuration. Details: {e}", "", []
 
 
 
 
 
 
 
36
 
37
  # Keep a lightweight global cache so we don’t reload embeddings on every click
38
  _VECTORSTORE_PATH = None
 
43
  _VECTORSTORE_PATH = ensure_index_downloaded()
44
  return _VECTORSTORE_PATH
45
 
46
+ # --- Gradio UI setup below ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  with gr.Blocks(title="DeepV for RTL (Model-Agnostic)") as demo:
48
  gr.Markdown("## DeepV for RTL Code Generation — Model-Agnostic (Bring Your Own API Key)")
49
 
50
  with gr.Row():
51
  with gr.Column(scale=2):
 
52
  with gr.Row():
53
  model_choice = gr.Dropdown(
54
  choices=[
 
56
  "gpt-4o-mini",
57
  "gpt-4.1",
58
  "gpt-5-chat-latest",
 
59
  ],
60
  value="gpt-4o",
61
  label="Model"
62
  )
63
  api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
64
 
65
+ gr.Markdown(
66
+ """
67
+ **Note:** Your API key is used for the current session only and is not saved or stored.
68
+ """
69
+ )
70
+
71
  spec = gr.Textbox(
72
  label="Design Specification (natural language or I/O contract)",
73
  placeholder="e.g., 8-bit UART transmitter with baud rate generator ...",
 
107
  )
108
 
109
if __name__ == "__main__":
    # `agent_module` exists only if the private core code loaded successfully.
    # (At module level, globals() is the correct namespace; the original's
    # locals() happened to work but was misleading.)
    if "agent_module" in globals():
        demo.launch()
    else:
        # BUG FIX: the original read `e` directly, but the name bound by
        # `except ... as e` is deleted when the handler exits, so this path
        # crashed with NameError. Fall back gracefully when `e` is absent.
        load_error = globals().get("e", "unknown initialization error")
        with gr.Blocks() as error_demo:
            gr.Markdown("# Initialization Error")
            gr.Markdown("An error occurred while loading the application code. Please check your configuration.")
            gr.Textbox(label="Error Details", value=str(load_error), lines=5)
        error_demo.launch()
+ error_demo.launch()