|
|
|
|
|
""" |
|
|
Test the newly fine-tuned CodeLlama model on training samples |
|
|
""" |
|
|
|
|
|
import json |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference")) |
|
|
|
|
|
from inference_codellama import load_local_model |
|
|
import torch |
|
|
from transformers import AutoTokenizer |
|
|
import re |
|
|
|
|
|
def extract_code_from_response(text):
    """Pull the code body out of a markdown-fenced response.

    Lookup order:
      1. the first ```verilog fence that has a closing ``` after it;
      2. otherwise the first ``` fence (its info string / rest of the opening
         line is skipped) that has a closing ``` after it;
      3. otherwise the whole text.

    The result is whitespace-stripped. Falsy input (``None`` or ``""``) is
    returned unchanged.
    """
    if not text:
        return text

    verilog_fence = '```verilog'
    fence = '```'

    # Preferred: an explicitly tagged Verilog block.
    tag_pos = text.find(verilog_fence)
    if tag_pos >= 0:
        body_start = tag_pos + len(verilog_fence)
        body_end = text.find(fence, body_start)
        if body_end >= 0:
            return text[body_start:body_end].strip()

    # Fallback: any fenced block; the body begins after the opening line
    # (or right after the fence when there is no newline at all).
    fence_pos = text.find(fence)
    if fence_pos >= 0:
        newline_pos = text.find('\n', fence_pos)
        body_start = fence_pos + len(fence) if newline_pos < 0 else newline_pos + 1
        body_end = text.find(fence, body_start)
        if body_end >= 0:
            return text[body_start:body_end].strip()

    return text.strip()
|
|
|
|
|
def generate_with_chat_format(model, tokenizer, instruction, max_new_tokens=1000, temperature=0.1,
                              max_input_length=1536):
    """Generate a completion for a prompt that is already in chat format.

    ``instruction`` is tokenized as-is: no chat template is applied here, so it
    must already contain the model's chat markers.

    Args:
        model: generation model exposing ``.device`` and ``.generate(**kwargs)``.
        tokenizer: tokenizer paired with ``model``.
        instruction: full chat-formatted prompt text.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature; ``temperature > 0`` enables sampling
            with ``top_p=0.95``, otherwise greedy decoding is used.
        max_input_length: maximum prompt length in tokens (default 1536);
            longer prompts are truncated by the tokenizer.

    Returns:
        The decoded continuation only (prompt tokens removed). Special tokens
        are kept, except that a single trailing EOS token string is stripped.
    """
    # NOTE(review): Hugging Face tokenizers truncate on the right by default
    # (``truncation_side``), so an over-long prompt would lose its tail,
    # including the closing chat marker. Confirm prompts fit in max_input_length.
    inputs = tokenizer(
        instruction,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    ).to(model.device)

    # A pad id of 0 is valid, so only fall back to EOS when it is actually unset.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    do_sample = temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": 1.2,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling-only knobs; omitted for greedy decoding, where transformers
        # ignores them and warns that they are set while do_sample=False.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = 0.95

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # The generated sequence starts with the prompt; keep only the new tokens.
    input_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][input_length:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)

    eos_token = tokenizer.eos_token
    # Guard against a missing/empty EOS string: endswith(None) raises, and
    # text[:-0] would erase the whole output.
    if eos_token and generated_text.endswith(eos_token):
        generated_text = generated_text[:-len(eos_token)].rstrip()

    return generated_text
|
|
|
|
|
def analyze_code_quality(generated_text):
    """Heuristically check whether ``generated_text`` looks like Verilog code.

    Keywords are matched as whole words (case-insensitive). As a result,
    ``endmodule`` on its own does not count as ``module``, and words such as
    ``register`` or ``inputs`` do not count as the keywords ``reg``/``input``.
    The SystemVerilog forms ``always_ff``/``always_comb``/``always_latch``
    count as ``always``.

    Args:
        generated_text: model output to inspect; ``None`` is treated as "".

    Returns:
        dict with boolean flags ``has_module``, ``has_endmodule``,
        ``has_verilog_keywords``, ``has_code_blocks`` and ``is_verilog``
        (module, endmodule and at least one keyword all present).
    """
    text = generated_text or ""

    def _has_word(alternatives):
        # \b on both sides prevents matches inside longer identifiers/words.
        return re.search(rf"\b(?:{alternatives})\b", text, re.IGNORECASE) is not None

    has_module = _has_word("module")
    has_endmodule = _has_word("endmodule")
    has_verilog_keywords = _has_word(r"input|output|reg|wire|assign|always(?:_ff|_comb|_latch)?")
    has_code_blocks = "```" in text

    return {
        "has_module": has_module,
        "has_endmodule": has_endmodule,
        "has_verilog_keywords": has_verilog_keywords,
        "has_code_blocks": has_code_blocks,
        "is_verilog": has_module and has_endmodule and has_verilog_keywords,
    }
|
|
|
|
|
def _load_samples(path, limit):
    """Return up to ``limit`` parsed JSON records from the JSONL file at ``path``.

    Blank lines are skipped and do not count toward ``limit``.
    """
    samples = []
    if limit <= 0:
        return samples
    # Explicit UTF-8 so the dataset decodes identically on every platform
    # (the locale default encoding is not UTF-8 everywhere, e.g. Windows).
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            samples.append(json.loads(line))
            if len(samples) >= limit:
                break
    return samples


def _task_summary(instruction):
    """Return a short description of the task contained in a chat-format prompt.

    Looks at the text before the first ``[/INST]`` marker, or at the last 200
    characters when there is no marker. If that text contains "Generate",
    everything after its first occurrence is returned (stripped). Otherwise
    the last 150 characters are returned.
    """
    if "[/INST]" in instruction:
        user_part = instruction.split("[/INST]", 1)[0]
    else:
        user_part = instruction[-200:]
    _, marker, task = user_part.partition("Generate")
    if marker:
        return task.strip()
    return user_part[-150:]


def _print_analysis(generated_response, generated_code):
    """Print length statistics and the keyword-based Verilog verdict for one generation."""
    quality = analyze_code_quality(generated_response)

    print("\n📊 Analysis:")
    print(f" Response length: {len(generated_response)} chars")
    print(f" Extracted code length: {len(generated_code)} chars")
    print(f" Contains 'module': {quality['has_module']}")
    print(f" Contains 'endmodule': {quality['has_endmodule']}")
    print(f" Contains Verilog keywords: {quality['has_verilog_keywords']}")
    print(f" Contains code blocks: {quality['has_code_blocks']}")

    if quality['is_verilog']:
        print(" ✅ STATUS: Generated valid Verilog code!")
    elif quality['has_module']:
        print(" ⚠️ STATUS: Partial Verilog code (missing endmodule or keywords)")
    else:
        print(" ❌ STATUS: Not generating Verilog code")


def main(num_samples=2):
    """Run the fine-tuned CodeLlama model on the first training samples and report results.

    Loads up to ``num_samples`` non-blank records from the chat-format training
    split. For each record it prints the task and the expected code, generates a
    response for the record's ``instruction`` with ``generate_with_chat_format``,
    and prints the output together with a keyword-based Verilog check.

    Args:
        num_samples: number of dataset records to test (default 2).
    """
    script_dir = Path(__file__).parent
    model_path = script_dir / "training-outputs" / "codellama-fifo-v2-chat"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split_chat_format" / "train.jsonl"

    print("=" * 80)
    print("🧪 TESTING NEW FINE-TUNED MODEL ON TRAINING SAMPLES")
    print("=" * 80)
    print(f"Model: {model_path}")
    print(f"Dataset: {train_dataset}")
    print("=" * 80)

    if not train_dataset.is_file():
        print(f"❌ Error: Dataset not found: {train_dataset}")
        return

    samples = _load_samples(train_dataset, num_samples)
    if len(samples) < num_samples:
        print(f"❌ Error: Only found {len(samples)} samples in dataset")
        return

    print("\n📦 Loading model...")
    model, tokenizer = load_local_model(
        str(model_path),
        # Pass the base model only if it exists locally; otherwise None.
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False,
    )
    print("✅ Model loaded!\n")

    for sample_idx, sample in enumerate(samples, 1):
        print("\n" + "=" * 80)
        print(f"📋 SAMPLE {sample_idx}")
        print("=" * 80)

        # `or ""` also covers keys that are present but null in the JSON.
        instruction = sample.get("instruction") or ""
        expected_response = sample.get("response") or ""
        expected_code = extract_code_from_response(expected_response)

        print("\n📝 Task:")
        print("-" * 80)
        print(_task_summary(instruction))
        print("-" * 80)

        print(f"\n🎯 Expected Response ({len(expected_response)} chars):")
        print("-" * 80)
        print(expected_code[:400] + "..." if len(expected_code) > 400 else expected_code)
        print("-" * 80)

        print("\n🤖 Generating with model...")
        generated_response = generate_with_chat_format(
            model,
            tokenizer,
            instruction,
            max_new_tokens=1000,
            temperature=0.1,
        )
        generated_code = extract_code_from_response(generated_response)

        print("\n" + "=" * 80)
        print("✅ GENERATED OUTPUT:")
        print("=" * 80)
        print(generated_response[:1000] + "..." if len(generated_response) > 1000 else generated_response)
        print("=" * 80)

        _print_analysis(generated_response, generated_code)

        print("\n" + "-" * 80)

    print("\n" + "=" * 80)
    print("✅ TESTING COMPLETE")
    print("=" * 80)
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
|
|
|
|
|
|