semiconductor-pipeline / src /models /llm_generator.py
Sai Kumar Taraka
feat: Add actual AI/ML capabilities with LLM, semantic embeddings, and reinforcement learning
9e8e9e2
import logging
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
import json
import re
logger = logging.getLogger("uvmgen.ml.llm")
class LLMType(Enum):
CODEGEN = "codegen"
CODET5 = "codet5"
CODEBERT = "codebert"
STARCODER = "starcoder"
LLAMA = "llama"
MISTRAL = "mistral"
FALLBACK = "fallback"
@dataclass
class LLMGenerationResult:
generated_code: str
prompt_used: str
model_name: str
tokens_generated: int
confidence: float = 0.5
metadata: Dict[str, Any] = field(default_factory=dict)
warnings: List[str] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
class LLMCodeGenerator:
_instance: Optional["LLMCodeGenerator"] = None
_model = None
_tokenizer = None
_model_name: str = "Salesforce/codegen-350M-mono"
_device: str = "cpu"
_initialized: bool = False
_llm_type: LLMType = LLMType.FALLBACK
UVM_PROMPT_TEMPLATE = """
You are an expert in UVM (Universal Verification Methodology) and SystemVerilog.
Generate production-quality UVM testbench code based on the following specification.
SPECIFICATION:
{spec_text}
REQUIREMENTS:
- Follow UVM 1.2 conventions and best practices
- Use proper factory registration with `uvm_component_utils` or `uvm_object_utils`
- Include appropriate phases (build_phase, connect_phase, run_phase)
- Use TLM ports and exports for component communication
- Include proper configuration database usage if needed
- Generate synthesizable SystemVerilog code
{context_examples}
Generate the {file_type} for this specification. Return only the SystemVerilog code, no explanations.
"""
FEW_SHOT_EXAMPLES = {
"driver": """
EXAMPLE DRIVER:
class my_driver extends uvm_driver #(my_seq_item);
`uvm_component_utils(my_driver)
virtual my_if vif;
function new(string name = "my_driver", uvm_component parent = null);
super.new(name, parent);
endfunction
function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (!uvm_config_db#(virtual my_if)::get(this, "", "vif", vif))
`uvm_fatal(get_type_name(), "Virtual interface not found")
endfunction
task run_phase(uvm_phase phase);
forever begin
seq_item_port.get_next_item(req);
drive_item(req);
seq_item_port.item_done();
end
endtask
task drive_item(my_seq_item item);
@(posedge vif.clk);
vif.valid <= 1'b1;
vif.data <= item.data;
@(posedge vif.clk);
vif.valid <= 1'b0;
endtask
endclass
""",
"monitor": """
EXAMPLE MONITOR:
class my_monitor extends uvm_monitor;
`uvm_component_utils(my_monitor)
uvm_analysis_port #(my_seq_item) item_collected_port;
virtual my_if vif;
function new(string name = "my_monitor", uvm_component parent = null);
super.new(name, parent);
item_collected_port = new("item_collected_port", this);
endfunction
function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (!uvm_config_db#(virtual my_if)::get(this, "", "vif", vif))
`uvm_fatal(get_type_name(), "Virtual interface not found")
endfunction
task run_phase(uvm_phase phase);
my_seq_item item;
forever begin
@(posedge vif.clk);
if (vif.valid) begin
item = my_seq_item::type_id::create("item");
item.data = vif.data;
item_collected_port.write(item);
end
end
endtask
endclass
""",
"agent": """
EXAMPLE AGENT:
class my_agent extends uvm_agent;
`uvm_component_utils(my_agent)
my_driver driver;
my_monitor monitor;
my_sequencer sequencer;
uvm_analysis_port #(my_seq_item) item_collected_port;
function new(string name = "my_agent", uvm_component parent = null);
super.new(name, parent);
item_collected_port = new("item_collected_port", this);
endfunction
function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (get_is_active() == UVM_ACTIVE) begin
driver = my_driver::type_id::create("driver", this);
sequencer = my_sequencer::type_id::create("sequencer", this);
end
monitor = my_monitor::type_id::create("monitor", this);
endfunction
function void connect_phase(uvm_phase phase);
super.connect_phase(phase);
if (get_is_active() == UVM_ACTIVE) begin
driver.seq_item_port.connect(sequencer.seq_item_export);
end
monitor.item_collected_port.connect(item_collected_port);
endfunction
endclass
""",
}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
if self._initialized:
return
if model_name:
self._model_name = model_name
if device:
self._device = device
self._initialized = False
self._model = None
self._tokenizer = None
self._detect_llm_type()
def _detect_llm_type(self):
name_lower = self._model_name.lower()
if "codegen" in name_lower:
self._llm_type = LLMType.CODEGEN
elif "codet5" in name_lower:
self._llm_type = LLMType.CODET5
elif "codebert" in name_lower:
self._llm_type = LLMType.CODEBERT
elif "starcoder" in name_lower or "starcoder" in name_lower:
self._llm_type = LLMType.STARCODER
elif "llama" in name_lower:
self._llm_type = LLMType.LLAMA
elif "mistral" in name_lower:
self._llm_type = LLMType.MISTRAL
else:
self._llm_type = LLMType.FALLBACK
def _load_model(self):
if self._initialized and self._model is not None:
return
if self._llm_type == LLMType.FALLBACK:
logger.info("LLMCodeGenerator using fallback mode (template-based)")
self._initialized = True
return
try:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
if self._device == "auto":
self._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Loading LLM: %s on %s", self._model_name, self._device)
self._tokenizer = AutoTokenizer.from_pretrained(self._model_name)
if self._llm_type == LLMType.CODET5:
self._model = AutoModelForSeq2SeqLM.from_pretrained(
self._model_name,
torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
)
else:
self._model = AutoModelForCausalLM.from_pretrained(
self._model_name,
torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
)
self._model.to(self._device)
self._model.eval()
if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
self._initialized = True
logger.info("LLM loaded successfully")
except ImportError as e:
logger.warning(
"Could not load LLM (missing dependencies: %s). Using fallback mode.",
e,
)
self._llm_type = LLMType.FALLBACK
self._initialized = True
except Exception as e:
logger.warning(
"Could not load LLM (%s). Using fallback mode.",
e,
)
self._llm_type = LLMType.FALLBACK
self._initialized = True
def is_available(self) -> bool:
self._load_model()
return self._initialized and self._llm_type != LLMType.FALLBACK
def _spec_to_text(self, spec_dict: Dict[str, Any]) -> str:
lines = []
if "design_name" in spec_dict:
lines.append(f"Design Name: {spec_dict['design_name']}")
if "protocol" in spec_dict:
lines.append(f"Protocol: {spec_dict['protocol']}")
if "signals" in spec_dict:
lines.append("\nSignals:")
for sig in spec_dict["signals"]:
name = sig.get("name", "unknown")
direction = sig.get("direction", "inout")
width = sig.get("width", 1)
desc = sig.get("description", "")
lines.append(f" - {name}: {direction}, width={width} {desc}")
if "registers" in spec_dict:
lines.append("\nRegisters:")
for reg in spec_dict["registers"]:
name = reg.get("name", "unknown")
addr = reg.get("address", "0x0")
width = reg.get("width", 32)
lines.append(f" - {name}: addr={addr}, width={width}")
if "features" in spec_dict:
lines.append("\nFeatures:")
for feat in spec_dict["features"]:
lines.append(f" - {feat}")
return "\n".join(lines)
def _build_prompt(
self,
spec_dict: Dict[str, Any],
file_type: str,
use_few_shot: bool = True,
) -> str:
spec_text = self._spec_to_text(spec_dict)
context_examples = ""
if use_few_shot and file_type in self.FEW_SHOT_EXAMPLES:
context_examples = self.FEW_SHOT_EXAMPLES[file_type]
prompt = self.UVM_PROMPT_TEMPLATE.format(
spec_text=spec_text,
file_type=file_type,
context_examples=context_examples,
)
return prompt.strip()
def _extract_code(self, text: str) -> str:
code_block_patterns = [
r"```systemverilog\s+(.*?)```",
r"```verilog\s+(.*?)```",
r"```sv\s+(.*?)```",
r"```\s+(.*?)```",
]
for pattern in code_block_patterns:
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
if match:
return match.group(1).strip()
return text.strip()
def _fallback_generate(
self,
spec_dict: Dict[str, Any],
file_type: str,
templates: Optional[Dict[str, str]] = None,
) -> LLMGenerationResult:
design_name = spec_dict.get("design_name", "unknown").lower()
fallback_templates = {
"driver": f"""
class {design_name}_driver extends uvm_driver #({design_name}_seq_item);
`uvm_component_utils({design_name}_driver)
virtual {design_name}_if vif;
function new(string name = "{design_name}_driver", uvm_component parent = null);
super.new(name, parent);
endfunction
function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (!uvm_config_db#(virtual {design_name}_if)::get(this, "", "vif", vif))
`uvm_fatal(get_type_name(), "Virtual interface not found in config DB")
endfunction
task run_phase(uvm_phase phase);
forever begin
seq_item_port.get_next_item(req);
drive_item(req);
seq_item_port.item_done();
end
endtask
task drive_item({design_name}_seq_item item);
// Implement drive logic based on item
@(posedge vif.clk);
endtask
endclass
""",
"monitor": f"""
class {design_name}_monitor extends uvm_monitor;
`uvm_component_utils({design_name}_monitor)
uvm_analysis_port #({design_name}_seq_item) item_collected_port;
virtual {design_name}_if vif;
function new(string name = "{design_name}_monitor", uvm_component parent = null);
super.new(name, parent);
item_collected_port = new("item_collected_port", this);
endfunction
function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (!uvm_config_db#(virtual {design_name}_if)::get(this, "", "vif", vif))
`uvm_fatal(get_type_name(), "Virtual interface not found in config DB")
endfunction
task run_phase(uvm_phase phase);
{design_name}_seq_item item;
forever begin
@(posedge vif.clk);
// Sample signals and create item
end
endtask
endclass
""",
"agent": f"""
class {design_name}_agent extends uvm_agent;
`uvm_component_utils({design_name}_agent)
{design_name}_driver driver;
{design_name}_monitor monitor;
{design_name}_sequencer sequencer;
uvm_analysis_port #({design_name}_seq_item) item_collected_port;
function new(string name = "{design_name}_agent", uvm_component parent = null);
super.new(name, parent);
item_collected_port = new("item_collected_port", this);
endfunction
function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (get_is_active() == UVM_ACTIVE) begin
driver = {design_name}_driver::type_id::create("driver", this);
sequencer = {design_name}_sequencer::type_id::create("sequencer", this);
end
monitor = {design_name}_monitor::type_id::create("monitor", this);
endfunction
function void connect_phase(uvm_phase phase);
super.connect_phase(phase);
if (get_is_active() == UVM_ACTIVE) begin
driver.seq_item_port.connect(sequencer.seq_item_export);
end
monitor.item_collected_port.connect(item_collected_port);
endfunction
endclass
""",
}
if templates and file_type in templates:
code = templates[file_type]
elif file_type in fallback_templates:
code = fallback_templates[file_type]
else:
code = f"// {file_type} for {design_name} - template placeholder"
return LLMGenerationResult(
generated_code=code,
prompt_used=f"// Fallback generation for {file_type}",
model_name="fallback_template",
tokens_generated=len(code.split()),
confidence=0.3,
warnings=["Using fallback template generation (LLM not available)"],
)
def generate(
self,
spec_dict: Dict[str, Any],
file_type: str,
use_few_shot: bool = True,
max_tokens: int = 1024,
temperature: float = 0.2,
templates: Optional[Dict[str, str]] = None,
) -> LLMGenerationResult:
self._load_model()
prompt = self._build_prompt(spec_dict, file_type, use_few_shot)
if self._llm_type == LLMType.FALLBACK or self._model is None:
return self._fallback_generate(spec_dict, file_type, templates)
try:
import torch
inputs = self._tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=1024,
padding=True,
)
inputs = {k: v.to(self._device) for k, v in inputs.items()}
with torch.no_grad():
if self._llm_type == LLMType.CODET5:
outputs = self._model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=temperature > 0,
num_return_sequences=1,
pad_token_id=self._tokenizer.pad_token_id,
eos_token_id=self._tokenizer.eos_token_id,
)
else:
outputs = self._model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=temperature > 0,
num_return_sequences=1,
pad_token_id=self._tokenizer.pad_token_id,
eos_token_id=self._tokenizer.eos_token_id,
)
generated_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
if generated_text.startswith(prompt):
generated_text = generated_text[len(prompt) :].strip()
code = self._extract_code(generated_text)
tokens_generated = len(outputs[0]) - inputs["input_ids"].shape[1]
confidence = 0.7
if "uvm_component_utils" in code or "uvm_object_utils" in code:
confidence += 0.1
if "class" in code and "extends" in code:
confidence += 0.05
if "build_phase" in code or "run_phase" in code:
confidence += 0.05
if "endclass" in code:
confidence += 0.05
confidence = min(confidence, 0.95)
return LLMGenerationResult(
generated_code=code,
prompt_used=prompt,
model_name=self._model_name,
tokens_generated=tokens_generated,
confidence=confidence,
warnings=[],
)
except Exception as e:
logger.warning("Error during LLM generation: %s. Using fallback.", e)
result = self._fallback_generate(spec_dict, file_type, templates)
result.warnings.append(f"LLM generation failed: {str(e)}")
return result
def generate_batch(
self,
spec_dict: Dict[str, Any],
file_types: List[str],
use_few_shot: bool = True,
max_tokens: int = 1024,
temperature: float = 0.2,
templates: Optional[Dict[str, str]] = None,
) -> Dict[str, LLMGenerationResult]:
results = {}
for file_type in file_types:
results[file_type] = self.generate(
spec_dict=spec_dict,
file_type=file_type,
use_few_shot=use_few_shot,
max_tokens=max_tokens,
temperature=temperature,
templates=templates.get(file_type) if templates else None,
)
return results