j-silv committed on
Commit
b03416e
·
1 Parent(s): b4ab4e0

Add OpenAI chatbot

Browse files
Files changed (6) hide show
  1. .gitignore +2 -1
  2. autohdl/data.py +13 -0
  3. autohdl/linter.py +23 -0
  4. autohdl/llm.py +52 -5
  5. autohdl/test_llm.py +50 -0
  6. requirements.txt +3 -1
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  __pycache__
2
  .venv
3
  venv
4
- .vscode
 
 
1
  __pycache__
2
  .venv
3
  venv
4
+ .vscode
5
+ .env
autohdl/data.py CHANGED
@@ -1,6 +1,19 @@
1
  import re
2
  from datasets import load_dataset
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def extract_description(prompt):
5
  """Use regex to extract description from prompt
6
 
 
1
  import re
2
  from datasets import load_dataset
3
 
4
+
5
def extract_header(prompt):
    """Use regex to extract the Verilog module header from a prompt.

    Args:
        prompt: String expected to contain a ``module ... ;`` declaration.

    Returns:
        The module header, from ``module`` through the first ``;``,
        stripped of surrounding whitespace.

    Raises:
        ValueError: If no module header is found in the prompt.
    """
    # Non-greedy up to the FIRST ';', with DOTALL so headers whose port
    # lists span multiple lines still match. The previous greedy,
    # single-line pattern (``module.*;``) over-matched to the last ';'
    # on the line and could never match a multi-line header.
    module_re = re.compile(r"module.*?;", re.DOTALL)

    match = module_re.search(prompt)
    if match is None:
        # Explicit check instead of a bare ``except:`` around ``.group(0)``,
        # which also masked unrelated errors.
        raise ValueError("Prompt is not in expected format when extracting header")

    return match.group(0).strip()
16
+
17
  def extract_description(prompt):
18
  """Use regex to extract description from prompt
19
 
autohdl/linter.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+
3
def linter(code: str) -> str:
    """A tool that runs Verilog syntax linting (with Verilator) on input code string.

    Args:
        code: A string representing the Verilog code to be linted.

    Returns:
        A human-readable message: a success string when linting passes, or
        Verilator's combined stdout/stderr output when it fails.
    """
    import os
    import tempfile

    # Write to a per-call temp directory instead of a fixed "code.v" in the
    # CWD, so concurrent calls don't clobber each other and no stray file
    # is left behind after linting.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "code.v")
        with open(path, "w", encoding="utf-8") as f:
            f.write(code)

        try:
            subprocess.run(
                ["verilator", "--lint-only", "-Wall", path],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # fold warnings/errors into stdout
                check=True,
                encoding="utf-8",
            )
            return "The linting passed successfully - no changes necessary"
        except subprocess.CalledProcessError as e:
            return f"Verilator linting gave an error. Please investigate and fix: {e.stdout}"
23
+
autohdl/llm.py CHANGED
@@ -1,15 +1,62 @@
1
  import outlines
2
  from outlines.inputs import Chat
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from .data import data
 
 
5
 
6
- system_prompt = ("You only complete chats with syntax correct Verilog code. "
7
- "End the Verilog module code completion with 'endmodule'. "
8
- "Do not include module, input and output definitions.")
9
 
10
  class LLM:
 
 
 
 
11
  def __init__(self):
12
- self.system_prompt = system_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def load_model(self, use_cpu=True):
15
  if use_cpu:
 
1
  import outlines
2
  from outlines.inputs import Chat
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import openai
5
+ import tiktoken
6
+ from dotenv import load_dotenv
7
 
 
 
 
8
 
9
class LLM:
    """Base class for chat LLM backends.

    Holds the default Verilog-completion system prompt and loads local
    environment credentials (e.g. API keys) from a ``.env`` file on init.
    """

    # Shared default system prompt; subclasses may override per instance.
    system_prompt = ("You only complete chats with syntax correct Verilog code. "
                     "End the Verilog module code completion with 'endmodule'. "
                     "Do not include module, input and output definitions.")

    def __init__(self):
        # Pull secrets from a local .env file into the process environment.
        load_dotenv()
16
+
17
+
18
class OpenAI(LLM):
    """Chat backend using the OpenAI Responses API.

    Keeps a running conversation in ``self.messages`` (system prompt always
    at index 0) and appends each user message and assistant reply.
    """

    def __init__(self, system_prompt=None, max_context_len=1000, model="gpt-5-nano"):
        """
        Args:
            system_prompt: Optional override for the class-level system prompt.
            max_context_len: Soft cap used by ``truncate`` (counted in
                characters of message content, not tokens).
            model: OpenAI model name to query.
        """
        super().__init__()

        if system_prompt:
            self.system_prompt = system_prompt

        self.max_context_len = max_context_len
        self.model = model

        # Conversation history; index 0 is always the system prompt.
        self.messages = [{"role": "system", "content": self.system_prompt}]

        self.client = openai.OpenAI()

    def __call__(self, message):
        """Send one user message and return the assistant's reply text."""
        self.messages.append({"role": "user", "content": message})

        response = self.client.responses.create(
            model=self.model,
            input=self.messages,
            service_tier="flex"
        )

        self.messages.append({"role": "assistant", "content": response.output_text})

        return response.output_text

    def truncate(self):
        """Drop the oldest non-system messages until the history fits.

        Bug fix: the previous implementation looped forever
        (``while total_len > self.max_context_len: pass`` never changed
        ``total_len``). This version evicts messages starting just after
        the system prompt until the total content length is at or below
        ``max_context_len``; the system prompt itself is never evicted.

        NOTE(review): length is measured in characters, not tokens — the
        file imports tiktoken, so confirm whether token counting was the
        real intent.
        """
        total_len = sum(len(message["content"]) for message in self.messages)

        # Index 1 is the oldest non-system message; stop once only the
        # system prompt remains even if it alone exceeds the cap.
        while total_len > self.max_context_len and len(self.messages) > 1:
            total_len -= len(self.messages[1]["content"])
            del self.messages[1]
57
+
58
+
59
+ class HuggingFace(LLM):
60
 
61
  def load_model(self, use_cpu=True):
62
  if use_cpu:
autohdl/test_llm.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .llm import OpenAI
2
+ from .linter import linter
3
+ from .data import data, extract_header
4
+
5
+
6
def main():
    """End-to-end smoke test of the OpenAI chatbot.

    Generates a Verilog completion from a dataset description, re-attaches
    the module header, runs the Verilator linter on the result, and feeds
    the lint report back to the model for a second pass.
    """
    separator = "-------------------------------------------"

    def show(label, value):
        # One banner-delimited section per artifact; replaces seven
        # copy-pasted print blocks from the original.
        print(separator)
        print(f"{label}:\n", value)
        print(separator + "\n")

    llm = OpenAI()
    ds = data()

    message = ds['description'][0]["high_level_global_summary"]
    expected = ds['code'][0]

    show("System prompt", llm.messages[0]['content'])
    show("User prompt", message)
    show("Expected", expected)

    response = llm(message)
    show("Response", response)

    # The system prompt tells the model NOT to emit the module header, so
    # re-attach it here to give the linter a complete module.
    header = extract_header(message)
    lint_input = header + "\n" + response
    show("Linter input", lint_input)

    lint_result = linter(lint_input)
    show("Linter result", lint_result)

    # Feed the lint report back so the model can correct its own output.
    second_response = llm(lint_result)
    show("Second response", second_response)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
requirements.txt CHANGED
@@ -2,4 +2,6 @@ transformers
2
  datasets
3
  outlines[transformers]
4
  streamlit
5
- streamlit_code_editor
 
 
 
2
  datasets
3
  outlines[transformers]
4
  streamlit
5
+ streamlit_code_editor
6
+ python-dotenv
7
+ openai