File size: 2,382 Bytes
3aec4f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
#!/usr/bin/env python3
"""
Quick test script to run inference on a single training sample
"""
import json
import sys
from pathlib import Path
# Add scripts to path
sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference"))
from inference_codellama import load_local_model, generate_with_local_model
def main():
    """Run a quick inference sanity check on one training sample.

    Loads the first record from the train split, prints the instruction and
    the expected (ground-truth) response, then loads the fine-tuned model and
    prints its generated output for a side-by-side eyeball comparison.
    """
    # Resolve all paths relative to this script so it works from any CWD.
    script_dir = Path(__file__).parent
    model_path = script_dir / "training-outputs" / "codellama-fifo-v1"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split" / "train.jsonl"

    # Load the first sample from the training data (JSONL: one JSON object per line).
    print("=" * 80)
    print("📖 Loading sample from training dataset...")
    print("=" * 80)
    with open(train_dataset, "r", encoding="utf-8") as f:
        sample = json.loads(f.readline())

    instruction = sample.get("instruction", "")
    expected_response = sample.get("response", "")

    print("\n📝 Instruction:")
    print("-" * 80)
    print(instruction)
    print("-" * 80)
    print("\n🎯 Expected Response (first 500 chars):")
    print("-" * 80)
    print(expected_response[:500])
    if len(expected_response) > 500:
        print("...")
    print("-" * 80)

    # Load the fine-tuned model; pass None for the base model path when that
    # directory is absent (load_local_model handles the fallback).
    print("\n" + "=" * 80)
    print("📦 Loading model...")
    print("=" * 80)
    model, tokenizer = load_local_model(
        str(model_path),
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False,
    )
    print("✅ Model loaded!\n")

    # Generate a response for the sampled instruction.
    print("=" * 80)
    print("🤖 Generating response...")
    print("=" * 80)
    print()
    try:
        generated_response = generate_with_local_model(
            model,
            tokenizer,
            instruction,
            max_new_tokens=800,
            temperature=0.3,
            stream=False,
        )
        print("\n" + "=" * 80)
        print("✅ GENERATED OUTPUT:")
        print("=" * 80)
        print(generated_response)
        print("=" * 80)
        print(f"\n📊 Output length: {len(generated_response)} characters")
    except Exception as e:
        # Best-effort diagnostics: keep the script alive on generation failure
        # and show the full traceback instead of crashing.
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
|