codellama-fine-tuning / test_single_sample.py
Prithvik-1's picture
Upload test_single_sample.py with huggingface_hub
3aec4f5 verified
#!/usr/bin/env python3
"""
Quick test script to run inference on a single training sample
"""
import json
import sys
from pathlib import Path
# Add scripts to path
sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference"))
from inference_codellama import load_local_model, generate_with_local_model
def _load_first_sample(dataset_path):
    """Return the first non-blank JSONL record of *dataset_path*, parsed."""
    # JSONL files are UTF-8 by convention; do not rely on the platform default.
    with open(dataset_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                return json.loads(line)
    raise ValueError(f"No samples found in {dataset_path}")


def main():
    """Run inference on the first training sample and print the result.

    Reads the first record of ``datasets/processed/split/train.jsonl``
    (relative to this script), prints its ``instruction`` field and the
    first 500 characters of its ``response`` field, loads the model from
    ``training-outputs/codellama-fifo-v1`` through ``load_local_model`` and
    prints the text returned by ``generate_with_local_model`` for the
    instruction.

    Errors raised during generation are caught and reported with a
    traceback; they are not re-raised.

    Raises:
        FileNotFoundError: If the training dataset file does not exist.
        ValueError: If the dataset contains no records, or its first record
            is not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    separator = "=" * 80
    divider = "-" * 80
    preview_chars = 500

    # Paths (all resolved relative to the directory holding this script)
    script_dir = Path(__file__).parent
    model_path = script_dir / "training-outputs" / "codellama-fifo-v1"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split" / "train.jsonl"

    # Load first sample from training data
    print(separator)
    print("📚 Loading sample from training dataset...")
    print(separator)

    sample = _load_first_sample(train_dataset)
    instruction = sample.get("instruction", "")
    expected_response = sample.get("response", "")

    print("\n📝 Instruction:")
    print(divider)
    print(instruction)
    print(divider)

    print("\n🎯 Expected Response (first 500 chars):")
    print(divider)
    print(expected_response[:preview_chars])
    if len(expected_response) > preview_chars:
        print("...")
    print(divider)

    # Load model
    print("\n" + separator)
    print("📦 Loading model...")
    print(separator)

    model, tokenizer = load_local_model(
        str(model_path),
        # The base model path is only passed when it exists on disk.
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False,
    )
    print("✅ Model loaded!\n")

    # Generate
    print(separator)
    print("🤖 Generating response...")
    print(separator)
    print()

    try:
        generated_response = generate_with_local_model(
            model,
            tokenizer,
            instruction,
            max_new_tokens=800,
            temperature=0.3,
            stream=False,
        )
    except Exception as e:
        # Top-level boundary of a diagnostic script: report the failure with
        # full context instead of letting it vanish.
        print(f"\n❌ Error: {e}")
        import traceback

        traceback.print_exc()
    else:
        print("\n" + separator)
        print("✅ GENERATED OUTPUT:")
        print(separator)
        print(generated_response)
        print(separator)
        print(f"\n📊 Output length: {len(generated_response)} characters")
# Run the single-sample check only when executed as a script, not on import.
if __name__ == "__main__":
    main()