#!/usr/bin/env python3
"""
Test inference on a single training sample with exact training format
"""

import json
import sys
import traceback
from pathlib import Path

# Add scripts to path so the project-local inference helpers can be imported
sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference"))

import torch


def generate_with_exact_format(model, tokenizer, instruction, max_new_tokens=800, temperature=0.1):
    """Generate a response using the EXACT training prompt format.

    During training each example is ``instruction + EOS + response + EOS``.
    For inference the prompt is therefore ``instruction + EOS`` and the model
    is expected to continue with the response.

    Args:
        model: Object exposing ``device`` and ``generate(**kwargs)``.
        tokenizer: Callable tokenizer exposing ``eos_token``, ``eos_token_id``,
            ``pad_token_id`` and ``decode``.
        instruction: Instruction text used as the prompt body.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; a value ``<= 0`` disables sampling.

    Returns:
        The decoded newly generated text (prompt tokens excluded), with a
        single trailing EOS token and trailing whitespace removed.
    """
    # A tokenizer without an EOS token would otherwise put the literal text
    # "None" into the prompt and break the suffix check further down.
    eos_token = tokenizer.eos_token or ""
    prompt = f"{instruction}{eos_token}"

    print("\n📝 Prompt Format (matching training):")
    print(f" Length: {len(prompt)} chars")
    print(f" First 200 chars: {prompt[:200]}...")
    print()

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536).to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    print("📊 Tokenized:")
    print(f" Input tokens: {prompt_token_count}")
    print()

    print("🤖 Generating...")
    print("=" * 80)

    # Token id 0 is a valid pad id, so compare against None, not truthiness.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    do_sample = temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": 1.2,  # Higher to prevent repetition
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling-only knobs are passed only when sampling is enabled.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = 0.9

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the newly generated tokens (after the prompt)
    generated_ids = outputs[0][prompt_token_count:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)

    # Remove EOS token if present at the end. The emptiness guard matters:
    # every string ends with "", and slicing with [:-0] would drop everything.
    if eos_token and generated_text.endswith(eos_token):
        generated_text = generated_text[:-len(eos_token)].rstrip()

    return generated_text


def extract_code_from_response(text):
    """Extract Verilog code from markdown code blocks.

    Lookup order:
      1. The first terminated ```` ```verilog ```` block.
      2. The first terminated generic ```` ``` ```` block; when the opening
         fence is followed by a newline before the closing fence, everything
         up to that newline (e.g. a language tag) is skipped.
      3. Otherwise the whole text, stripped.

    Args:
        text: Model response (or any string). Falsy values are returned as-is.

    Returns:
        The stripped contents of the selected code block, or the stripped
        input when no terminated block exists.
    """
    if not text:
        return text

    # Check for verilog code block
    verilog_tag = "```verilog"
    tag_pos = text.find(verilog_tag)
    if tag_pos != -1:
        start = tag_pos + len(verilog_tag)
        end = text.find("```", start)
        if end != -1:
            return text[start:end].strip()

    # Check for generic code block
    fence = text.find("```")
    if fence != -1:
        body_start = fence + 3
        closing = text.find("```", body_start)
        if closing != -1:
            # Only a newline *inside* the block ends the opening-fence line.
            # Searching past the closing fence would skip an inline block
            # such as "```code``` ...\n" entirely.
            newline = text.find("\n", body_start, closing)
            if newline != -1:
                body_start = newline + 1
            return text[body_start:closing].strip()

    return text.strip()


def _load_first_sample(dataset_path):
    """Return the first record of a JSONL file, parsed from JSON.

    Raises:
        ValueError: If the first line is empty or is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        first_line = f.readline()
    if not first_line.strip():
        raise ValueError(f"No training sample found on the first line of {dataset_path}")
    return json.loads(first_line)


def main():
    """Run the first training sample through the model at several temperatures.

    Loads the first record of the training split, prints the instruction and
    the expected response, loads the model via ``load_local_model`` and, for
    each temperature, prints the generation plus a few heuristic checks for
    Verilog output. A failure at one temperature is reported with a traceback
    and does not stop the remaining runs.
    """
    # Project-local import, resolved through the sys.path entry added at
    # module import time. Kept as the first statement so a missing inference
    # stack still fails before any work is done.
    from inference_codellama import load_local_model

    # Paths
    script_dir = Path(__file__).parent
    model_path = script_dir / "training-outputs" / "codellama-fifo-v1"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split" / "train.jsonl"

    print("=" * 80)
    print("🧪 TESTING SINGLE TRAINING SAMPLE (EXACT TRAINING FORMAT)")
    print("=" * 80)
    print(f"Model: {model_path}")
    print(f"Base: {base_model_path}")
    print("=" * 80)

    # Load first sample
    print("\n📚 Loading training sample #1...")
    sample = _load_first_sample(train_dataset)

    instruction = sample.get("instruction", "")
    expected_response = sample.get("response", "")
    expected_code = extract_code_from_response(expected_response)

    print(f"\n📝 Instruction ({len(instruction)} chars):")
    print("-" * 80)
    print(instruction)
    print("-" * 80)

    print(f"\n🎯 Expected Response ({len(expected_response)} chars):")
    print("-" * 80)
    if len(expected_response) > 500:
        print(expected_response[:500] + "...")
    else:
        print(expected_response)
    print("-" * 80)

    # Load model
    print("\n📦 Loading model...")
    model, tokenizer = load_local_model(
        str(model_path),
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False
    )
    print("✅ Model loaded!\n")

    # Test with different temperatures
    temperatures = [0.1, 0.2, 0.3]

    for temp in temperatures:
        print("\n" + "=" * 80)
        print(f"🔥 TESTING WITH TEMPERATURE: {temp}")
        print("=" * 80)

        try:
            generated_response = generate_with_exact_format(
                model,
                tokenizer,
                instruction,
                max_new_tokens=800,
                temperature=temp
            )

            generated_code = extract_code_from_response(generated_response)

            print("\n" + "=" * 80)
            print(f"✅ GENERATED OUTPUT (Temperature {temp}):")
            print("=" * 80)
            print(generated_response)
            print("=" * 80)

            print("\n📊 Statistics:")
            print(f" Full response length: {len(generated_response)} chars")
            print(f" Extracted code length: {len(generated_code)} chars")
            print(f" Expected code length: {len(expected_code)} chars")

            # Quick check if it contains module declaration
            lowered_response = generated_response.lower()
            has_module = "module" in lowered_response
            has_endmodule = "endmodule" in lowered_response
            has_verilog_code = "```verilog" in generated_response or (
                "module" in generated_response and "input" in generated_response
            )

            print("\n✅ Code Quality Check:")
            print(f" Contains 'module': {has_module}")
            print(f" Contains 'endmodule': {has_endmodule}")
            print(f" Looks like Verilog code: {has_verilog_code}")

            if has_verilog_code and has_endmodule:
                print(" ✅ STATUS: Generated Verilog code!")
            elif has_module:
                print(" ⚠️ STATUS: Partial code (missing endmodule or full implementation)")
            else:
                print(" ❌ STATUS: Not generating code (generating text instead)")

        except Exception as e:
            # Top-level boundary of a diagnostic script: report the failure
            # in full and continue with the next temperature.
            print(f"❌ Error with temperature {temp}: {e}")
            traceback.print_exc()


if __name__ == "__main__":
    main()