#!/usr/bin/env python3
# codellama-fine-tuning/test_new_model.py
"""
Test the newly fine-tuned CodeLlama model on training samples
"""
import json
import sys
from pathlib import Path

import torch

# Make the local inference helpers importable
sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference"))
from inference_codellama import load_local_model


def extract_code_from_response(text):
    """Extract Verilog code from markdown code blocks."""
    if not text:
        return text
    # Prefer a verilog-tagged code block
    if '```verilog' in text:
        start = text.find('```verilog') + len('```verilog')
        end = text.find('```', start)
        if end != -1:
            return text[start:end].strip()
    # Fall back to a generic code block
    if '```' in text:
        start = text.find('```')
        start_marker = text.find('\n', start)
        if start_marker == -1:
            start_marker = start + 3
        else:
            start_marker += 1
        end = text.find('```', start_marker)
        if end != -1:
            return text[start_marker:end].strip()
    return text.strip()
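

# Example behavior of the extractor above (illustrative comments added for
# clarity; the input strings are hypothetical, not taken from the dataset):
#   extract_code_from_response("Here:\n```verilog\nmodule t; endmodule\n```")
#     -> "module t; endmodule"
#   extract_code_from_response("no fences at all")
#     -> "no fences at all"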


def generate_with_chat_format(model, tokenizer, instruction, max_new_tokens=1000, temperature=0.1):
    """Generate using the chat-template format (instruction already has chat format)."""
    # The instruction already contains: <s>[INST]...[/INST]
    # During training we appended response + EOS; during inference the
    # instruction (ending with [/INST]) prompts the model to generate the response.
    inputs = tokenizer(instruction, return_tensors="pt", truncation=True, max_length=1536).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            top_p=0.95 if temperature > 0 else None,
            repetition_penalty=1.2,
            # Compare against None: a valid pad_token_id of 0 is falsy
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens
    input_length = inputs['input_ids'].shape[1]
    generated_ids = outputs[0][input_length:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
    # Remove a trailing EOS token if present
    if generated_text.endswith(tokenizer.eos_token):
        generated_text = generated_text[:-len(tokenizer.eos_token)].rstrip()
    return generated_text
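

# For reference, a minimal sketch of the prompt shape this function expects,
# inferred from the comments above. The exact template comes from the training
# data, so treat the literal string as an assumption:
#   prompt = "<s>[INST] Generate a Verilog FIFO module with depth 16. [/INST]"
#   response = generate_with_chat_format(model, tokenizer, prompt)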


def analyze_code_quality(generated_text):
    """Analyze whether the generated text looks like proper Verilog code."""
    has_module = "module" in generated_text.lower()
    has_endmodule = "endmodule" in generated_text.lower()
    has_verilog_keywords = any(
        kw in generated_text.lower()
        for kw in ["input", "output", "reg", "wire", "assign", "always"]
    )
    has_code_blocks = "```" in generated_text
    return {
        "has_module": has_module,
        "has_endmodule": has_endmodule,
        "has_verilog_keywords": has_verilog_keywords,
        "has_code_blocks": has_code_blocks,
        "is_verilog": has_module and has_endmodule and has_verilog_keywords,
    }
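

# Minimal offline sanity check of the two helpers above. This is an added
# illustrative sketch, not part of the original test flow; main() never calls
# it, and the sample string is hypothetical.
def _smoke_test_helpers():
    sample = "```verilog\nmodule t(input a, output b); assign b = a; endmodule\n```"
    code = extract_code_from_response(sample)
    assert code.startswith("module") and code.endswith("endmodule")
    assert analyze_code_quality(code)["is_verilog"]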


def main():
    script_dir = Path(__file__).parent
    model_path = script_dir / "training-outputs" / "codellama-fifo-v2-chat"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split_chat_format" / "train.jsonl"

    print("=" * 80)
    print("🧪 TESTING NEW FINE-TUNED MODEL ON TRAINING SAMPLES")
    print("=" * 80)
    print(f"Model: {model_path}")
    print(f"Dataset: {train_dataset}")
    print("=" * 80)

    # Load the first two samples
    samples = []
    with open(train_dataset, 'r') as f:
        for i, line in enumerate(f):
            if i >= 2:
                break
            if line.strip():
                samples.append(json.loads(line))
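
    # Assumption based on the keys read below: each JSONL record provides an
    # "instruction" (chat-formatted prompt) and a "response" (Verilog answer).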
    if len(samples) < 2:
        print(f"❌ Error: Only found {len(samples)} samples in dataset")
        return

    # Load model
    print("\n📦 Loading model...")
    model, tokenizer = load_local_model(
        str(model_path),
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False,
    )
    print("✅ Model loaded!\n")

    # Test each sample
    for sample_idx, sample in enumerate(samples, 1):
        print("\n" + "=" * 80)
        print(f"📝 SAMPLE {sample_idx}")
        print("=" * 80)
        instruction = sample.get("instruction", "")
        expected_response = sample.get("response", "")
        expected_code = extract_code_from_response(expected_response)

        # Extract the user message from the instruction for display
        if "[/INST]" in instruction:
            user_part = instruction.split("[/INST]")[0]
        else:
            user_part = instruction

        print("\n📋 Task:")
        print("-" * 80)
        if "Generate" in user_part:
            print(user_part.split("Generate", 1)[1].strip())
        else:
            print(user_part[-150:])
        print("-" * 80)

        print(f"\n🎯 Expected Response ({len(expected_response)} chars):")
        print("-" * 80)
        print((expected_code[:400] + "...") if len(expected_code) > 400 else expected_code)
        print("-" * 80)

        # Generate
        print("\n🤖 Generating with model...")
        generated_response = generate_with_chat_format(
            model,
            tokenizer,
            instruction,
            max_new_tokens=1000,
            temperature=0.1,
        )
        generated_code = extract_code_from_response(generated_response)

        print("\n" + "=" * 80)
        print("✅ GENERATED OUTPUT:")
        print("=" * 80)
        print((generated_response[:1000] + "...") if len(generated_response) > 1000 else generated_response)
        print("=" * 80)

        # Analysis
        quality = analyze_code_quality(generated_response)
        print("\n📊 Analysis:")
        print(f"   Response length: {len(generated_response)} chars")
        print(f"   Extracted code length: {len(generated_code)} chars")
        print(f"   Contains 'module': {quality['has_module']}")
        print(f"   Contains 'endmodule': {quality['has_endmodule']}")
        print(f"   Contains Verilog keywords: {quality['has_verilog_keywords']}")
        print(f"   Contains code blocks: {quality['has_code_blocks']}")
        if quality['is_verilog']:
            print("   ✅ STATUS: Generated valid Verilog code!")
        elif quality['has_module']:
            print("   ⚠️ STATUS: Partial Verilog code (missing endmodule or keywords)")
        else:
            print("   ❌ STATUS: Not generating Verilog code")
        print("\n" + "-" * 80)
print("\n" + "=" * 80)
print("βœ… TESTING COMPLETE")
print("=" * 80)
if __name__ == "__main__":
main()