|
|
|
|
|
"""Quick test script to run inference on a single training sample."""
|
|
|
|
|
import json
import sys
import traceback
from pathlib import Path
|
|
|
|
|
|
|
|
# Put <this file's directory>/scripts/inference at the front of sys.path so
# that `inference_codellama` (presumably located there -- confirm) resolves.
sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference"))

# This import must come after the sys.path change above, which is why it is
# not grouped with the top-of-file imports.
from inference_codellama import load_local_model, generate_with_local_model  # noqa: E402
|
|
|
|
|
def _read_first_sample(dataset_path):
    """Return the JSON value decoded from the first line of *dataset_path*.

    Args:
        dataset_path: Path to a JSONL file; only its first line is read.

    Returns:
        Whatever ``json.loads`` produces for the first line.

    Raises:
        FileNotFoundError: If the file does not exist (raised by ``open``).
        ValueError: If the first line is empty or whitespace-only. Invalid
            JSON raises ``json.JSONDecodeError``, a ``ValueError`` subclass.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(dataset_path, "r", encoding="utf-8") as f:
        first_line = f.readline()

    # Fail with a message that names the file instead of a bare JSON error.
    if not first_line.strip():
        raise ValueError(
            f"First line of {dataset_path} is empty; expected a JSON record"
        )
    return json.loads(first_line)


def main():
    """Run inference on the first training sample and print the result.

    Reads the first record of ``datasets/processed/split/train.jsonl``
    (relative to this file), prints its ``instruction`` and the first 500
    characters of its ``response``, loads the model found under
    ``training-outputs/codellama-fifo-v1`` and prints the text generated for
    the instruction.

    Errors while reading the sample or loading the model propagate to the
    caller. Errors raised during generation or while printing its output are
    reported on stdout together with a traceback and are not re-raised.
    """
    # NOTE(review): the leading glyphs in the banner strings below ("π", "β",
    # "π―", ...) look like mis-decoded emoji. They are kept as found because
    # the intended characters cannot be recovered from this file alone.
    script_dir = Path(__file__).parent
    model_path = script_dir / "training-outputs" / "codellama-fifo-v1"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split" / "train.jsonl"

    print("=" * 80)
    print("π Loading sample from training dataset...")
    print("=" * 80)

    sample = _read_first_sample(train_dataset)
    instruction = sample.get("instruction", "")
    expected_response = sample.get("response", "")

    print("\nπ Instruction:")
    print("-" * 80)
    print(instruction)
    print("-" * 80)

    print("\nπ― Expected Response (first 500 chars):")
    print("-" * 80)
    print(expected_response[:500])
    if len(expected_response) > 500:
        # Signal that the expected response was truncated for display.
        print("...")
    print("-" * 80)

    print("\n" + "=" * 80)
    print("π¦ Loading model...")
    print("=" * 80)

    model, tokenizer = load_local_model(
        str(model_path),
        # Pass the base model path only when that directory exists locally.
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False,
    )

    print("β Model loaded!\n")

    print("=" * 80)
    print("π€ Generating response...")
    print("=" * 80)
    print()

    try:
        generated_response = generate_with_local_model(
            model,
            tokenizer,
            instruction,
            max_new_tokens=800,
            temperature=0.3,
            stream=False,
        )

        print("\n" + "=" * 80)
        print("β GENERATED OUTPUT:")
        print("=" * 80)
        print(generated_response)
        print("=" * 80)
        print(f"\nπ Output length: {len(generated_response)} characters")
    except Exception as e:
        # Top-level boundary of a diagnostic script: report the failure with
        # its traceback instead of propagating it.
        print(f"\nβ Error: {e}")
        traceback.print_exc()
|
|
|
|
|
# Run the single-sample inference only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|