Add OpenAI chatbot
Files changed:
- .gitignore +2 -1
- autohdl/data.py +13 -0
- autohdl/linter.py +23 -0
- autohdl/llm.py +52 -5
- autohdl/test_llm.py +50 -0
- requirements.txt +3 -1
.gitignore
CHANGED
@@ -1,4 +1,5 @@
 __pycache__
 .venv
 venv
-.vscode
+.vscode
+.env
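The newly ignored .env file is presumably where the OpenAI credentials live: autohdl/llm.py now calls load_dotenv() in LLM.__init__, and the openai client reads OPENAI_API_KEY from the environment by default. A minimal sketch of such a file, with a placeholder value:

# .env (not committed; loaded by python-dotenv when an LLM is constructed)
OPENAI_API_KEY=<your-api-key>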
autohdl/data.py
CHANGED
@@ -1,6 +1,19 @@
 import re
 from datasets import load_dataset
 
+
+def extract_header(prompt):
+    """Use regex to extract module header from prompt"""
+
+    module_re = re.compile(r"module.*;")
+
+    try:
+        result = re.search(module_re, prompt).group(0)
+    except:
+        raise Exception("Prompt is not in expected format when extracting header")
+
+    return result.strip()
+
 def extract_description(prompt):
     """Use regex to extract description from prompt
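Note that module_re = re.compile(r"module.*;") only matches a header that sits on a single line, since '.' does not cross newlines by default, and the bare except also swallows unrelated errors. A hedged sketch of a multi-line-tolerant variant; extract_header_multiline is hypothetical and not part of this commit:

import re

def extract_header_multiline(prompt: str) -> str:
    """Hypothetical variant: tolerate module headers that span several lines.

    re.DOTALL lets '.' match newlines; the non-greedy '.*?' stops at the first
    ';' so the rest of the module body is not swallowed into the header.
    """
    match = re.search(r"module.*?;", prompt, flags=re.DOTALL)
    if match is None:
        raise ValueError("Prompt is not in expected format when extracting header")
    return match.group(0).strip()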
autohdl/linter.py
ADDED
@@ -0,0 +1,23 @@
+import subprocess
+
+def linter(code:str)-> str:
+    """A tool that runs Verilog syntax linting (with Verilator) on input code string
+    Args:
+        code: A string representing the Verilog code to be linted
+    """
+
+    with open("code.v", "w") as f:
+        f.write(code)
+
+    try:
+        subprocess.run(["verilator", "--lint-only", "-Wall", "code.v"],
+                       stdout=subprocess.PIPE,
+                       stderr=subprocess.STDOUT,
+                       check=True,
+                       encoding="utf-8")
+
+        return "The linting passed successfully - no changes necessary"
+
+    except subprocess.CalledProcessError as e:
+        return f"Verilator linting gave an error. Please investigate and fix: {e.stdout}"
+
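linter() writes to a fixed code.v in the current working directory, so concurrent calls (or a pre-existing file of that name) can interfere with each other. A sketch of an alternative that lints from a throwaway temporary directory instead; linter_tmpfile is hypothetical and not part of this commit:

import subprocess
import tempfile
from pathlib import Path


def linter_tmpfile(code: str) -> str:
    """Hypothetical variant of linter(): lint from a temporary .v file."""
    with tempfile.TemporaryDirectory() as tmpdir:
        path = Path(tmpdir) / "code.v"
        path.write_text(code)
        try:
            # Same Verilator invocation as the committed linter().
            subprocess.run(["verilator", "--lint-only", "-Wall", str(path)],
                           stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT,
                           check=True,
                           encoding="utf-8")
            return "The linting passed successfully - no changes necessary"
        except subprocess.CalledProcessError as e:
            return f"Verilator linting gave an error. Please investigate and fix: {e.stdout}"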
autohdl/llm.py
CHANGED
@@ -1,15 +1,62 @@
 import outlines
 from outlines.inputs import Chat
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
+import openai
+import tiktoken
+from dotenv import load_dotenv
 
-system_prompt = ("You only complete chats with syntax correct Verilog code. "
-                 "End the Verilog module code completion with 'endmodule'. "
-                 "Do not include module, input and output definitions.")
 
 class LLM:
+    system_prompt = ("You only complete chats with syntax correct Verilog code. "
+                     "End the Verilog module code completion with 'endmodule'. "
+                     "Do not include module, input and output definitions.")
+
     def __init__(self):
-
+        load_dotenv()
+
+
+class OpenAI(LLM):
+    def __init__(self, system_prompt=None, max_context_len=1000, model="gpt-5-nano"):
+        super().__init__()
+
+        if system_prompt:
+            self.system_prompt = system_prompt
+
+        self.max_context_len = max_context_len
+        self.model = model
+
+        self.messages = [{"role": "system", "content": self.system_prompt}]
+
+        self.client = openai.OpenAI()
+
+    def __call__(self, message):
+        self.messages.append({"role": "user", "content": message})
+
+        response = self.client.responses.create(
+            model=self.model,
+            input=self.messages,
+            service_tier="flex"
+        )
+
+        self.messages.append({"role": "assistant", "content": response.output_text})
+
+        return response.output_text
+
+    def truncate(self):
+        """Truncate off tokens until we reach max_context_len
+
+        Approach is to first determine the index where we are below the max token length
+        in the messages array. Once we find this index, then we truncate off that content until
+        we are below the max.
+        """
+        total_len = sum(len(message['content']) for message in self.messages)
+
+        i = len(self.messages)-1
+        while total_len > self.max_context_len:
+            pass
+
+
+class HuggingFace(LLM):
 
     def load_model(self, use_cpu=True):
         if use_cpu:
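As committed, truncate() never terminates once the total content length exceeds max_context_len: total_len is computed once, i is never used, and the while body is pass. A sketch of how the docstring's approach might be completed, dropping the oldest non-system messages; this is an assumption about the intended behaviour, and it measures length in characters just like the committed code (tiktoken is imported but not yet used):

def truncate_messages(messages, max_context_len):
    """Sketch (not in this commit) of the behaviour truncate() seems to aim for:
    drop the oldest non-system messages until the total content length is back
    under max_context_len. Lengths are counted in characters, like the committed
    code; tiktoken could be swapped in later for real token counts."""
    total_len = sum(len(m["content"]) for m in messages)

    # messages[0] is the system prompt, so evict from index 1 onward.
    while total_len > max_context_len and len(messages) > 1:
        removed = messages.pop(1)
        total_len -= len(removed["content"])

    return messages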
autohdl/test_llm.py
ADDED
@@ -0,0 +1,50 @@
+from .llm import OpenAI
+from .linter import linter
+from .data import data, extract_header
+
+
+def main():
+    llm = OpenAI()
+
+    ds = data()
+
+    message = ds['description'][0]["high_level_global_summary"]
+    expected = ds['code'][0]
+
+    print("-------------------------------------------")
+    print("System prompt:")
+    print(llm.messages[0]['content'])
+    print("-------------------------------------------\n")
+
+    print("-------------------------------------------")
+    print("User prompt:\n", message)
+    print("-------------------------------------------\n")
+
+    print("-------------------------------------------")
+    print("Expected:\n", expected)
+    print("-------------------------------------------\n")
+
+    response = llm(message)
+    print("-------------------------------------------")
+    print("Response:\n", response)
+    print("-------------------------------------------\n")
+
+    header = extract_header(message)
+    lint_input = header + "\n" + response
+    print("-------------------------------------------")
+    print("Linter input:\n", lint_input)
+    print("-------------------------------------------\n")
+
+    lint_result = linter(lint_input)
+    print("-------------------------------------------")
+    print("Linter result:\n", lint_result)
+    print("-------------------------------------------\n")
+
+    second_response = llm(lint_result)
+    print("-------------------------------------------")
+    print("Second response:\n", second_response)
+    print("-------------------------------------------\n")
+
+
+if __name__ == "__main__":
+    main()
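Because test_llm.py uses package-relative imports (from .llm import OpenAI), it presumably has to be run as a module from the repository root, for example with python -m autohdl.test_llm, with Verilator on PATH and the OpenAI key available via .env or the environment; it is a manual end-to-end exercise rather than a pytest-style unit test.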
requirements.txt
CHANGED
@@ -2,4 +2,6 @@ transformers
 datasets
 outlines[transformers]
 streamlit
-streamlit_code_editor
+streamlit_code_editor
+python-dotenv
+openai
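python-dotenv and openai cover the new OpenAI-backed chatbot. Note that tiktoken is imported in autohdl/llm.py but is not added to requirements.txt in this commit, and Verilator is a system tool rather than a Python package, so it must be installed separately for autohdl/linter.py to work.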