"""
Quick test script to run inference on a single training sample.

Reads the first record of ``datasets/processed/split/train.jsonl``, prints its
instruction and (truncated) expected response, loads the model from
``training-outputs/codellama-fifo-v1`` and prints the generated response.
"""

import json
import sys
from pathlib import Path

# Make ``scripts/inference`` (next to this file) importable regardless of the
# current working directory. ``resolve()`` turns a relative/symlinked
# ``__file__`` into an absolute path, and the membership check avoids
# stacking duplicate entries if this module is imported more than once.
_INFERENCE_DIR = Path(__file__).resolve().parent / "scripts" / "inference"
if str(_INFERENCE_DIR) not in sys.path:
    sys.path.insert(0, str(_INFERENCE_DIR))

# Must come after the sys.path tweak above, hence not at the very top.
from inference_codellama import load_local_model, generate_with_local_model  # noqa: E402


def _print_banner(title, *, leading_newline=False):
    """Print *title* framed by two 80-character ``=`` rules.

    Args:
        title: Text printed between the two rules.
        leading_newline: If true, emit a blank line before the top rule.
    """
    rule = "=" * 80
    print(("\n" if leading_newline else "") + rule)
    print(title)
    print(rule)


def _load_first_sample(dataset_path):
    """Return the first non-blank record of a JSON Lines file, decoded.

    Args:
        dataset_path: Path to a ``.jsonl`` file (one JSON value per line).

    Returns:
        The JSON value decoded from the first non-blank line.

    Raises:
        FileNotFoundError: If *dataset_path* does not exist.
        ValueError: If the file contains no non-blank lines, or (as
            ``json.JSONDecodeError``) if that line is not valid JSON.
    """
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can fail
    # on non-ASCII text in the dataset.
    with open(dataset_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines instead of feeding "" to json.loads, which
            # would raise an unhelpful "Expecting value" error.
            if line.strip():
                return json.loads(line)
    raise ValueError(f"No samples found in {dataset_path}")


def main():
    """Run the model on the first training sample and print the result.

    Paths are resolved relative to this file's directory. The base model path
    is passed only if it exists locally; otherwise ``None`` is passed.

    Returns:
        0 on success, 1 if ``generate_with_local_model`` raised (the error and
        traceback are printed). Failures while reading the sample or loading
        the model propagate to the caller.
    """
    script_dir = Path(__file__).parent
    model_path = script_dir / "training-outputs" / "codellama-fifo-v1"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split" / "train.jsonl"

    # 1. Load and show the reference sample.
    _print_banner("📂 Loading sample from training dataset...")

    sample = _load_first_sample(train_dataset)
    # NOTE(review): assumes each record is a JSON object with "instruction"
    # and "response" keys; missing keys fall back to "".
    instruction = sample.get("instruction", "")
    expected_response = sample.get("response", "")

    print("\n📝 Instruction:")
    print("-" * 80)
    print(instruction)
    print("-" * 80)

    print("\n🎯 Expected Response (first 500 chars):")
    print("-" * 80)
    print(expected_response[:500])
    if len(expected_response) > 500:
        print("...")
    print("-" * 80)

    # 2. Load the model.
    _print_banner("📦 Loading model...", leading_newline=True)

    model, tokenizer = load_local_model(
        str(model_path),
        # None when no local base model is present; how load_local_model
        # handles that is defined in inference_codellama — TODO confirm.
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False,
    )

    print("✅ Model loaded!\n")

    # 3. Generate and show the model's answer.
    _print_banner("🚀 Generating response...")
    print()

    # Only the generation call is guarded: this is the script's top-level
    # boundary, so report the failure and signal it via the return code
    # rather than exiting with success.
    try:
        generated_response = generate_with_local_model(
            model,
            tokenizer,
            instruction,
            max_new_tokens=800,
            temperature=0.3,
            stream=False,
        )
    except Exception as e:
        import traceback

        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        return 1

    _print_banner("✅ GENERATED OUTPUT:", leading_newline=True)
    print(generated_response)
    print("=" * 80)
    print(f"\n📊 Output length: {len(generated_response)} characters")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status to the shell; sys.exit(None) exits with 0.
    sys.exit(main())