File size: 3,277 Bytes
1137e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import annotations

import argparse
import json
from pathlib import Path

from routercore.model_router import extract_first_json_object
from training.format_dataset import build_inference_prompt
from training.train_lora import OptionalTrainingDependencyError


def load_inference_dependencies():
    """Lazily import the optional inference stack and hand it back by name.

    The imports happen inside the function so that importing this module
    does not itself require torch, peft, or transformers.

    Returns:
        A dict with the keys ``"torch"``, ``"PeftModel"``,
        ``"AutoModelForCausalLM"`` and ``"AutoTokenizer"`` mapped to the
        corresponding imported objects.

    Raises:
        OptionalTrainingDependencyError: If any of the imports raises
            ``ImportError``; the original error is chained as the cause.
    """
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM
        from transformers import AutoTokenizer
    except ImportError as import_error:
        message = (
            "Optional inference dependencies are unavailable. Install transformers, peft, and torch."
        )
        raise OptionalTrainingDependencyError(message) from import_error
    else:
        return dict(
            torch=torch,
            PeftModel=PeftModel,
            AutoModelForCausalLM=AutoModelForCausalLM,
            AutoTokenizer=AutoTokenizer,
        )


def run_lora_inference(
    *,
    base_model: str,
    adapter: Path,
    user_input: str,
    max_new_tokens: int,
) -> str:
    """Generate text for ``user_input`` using ``base_model`` with a LoRA adapter applied.

    The base model is loaded, wrapped with the adapter via
    ``PeftModel.from_pretrained``, moved to CUDA when available (CPU
    otherwise) and run with greedy decoding (``do_sample=False``).

    Args:
        base_model: Identifier or path passed to ``from_pretrained`` for the
            base causal language model (and for the tokenizer fallback).
        adapter: Path of the LoRA adapter. If it exists, its tokenizer is
            tried first.
        user_input: Text handed to ``build_inference_prompt`` to build the
            prompt that is fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens,
            forwarded to ``model.generate``.

    Returns:
        The decoded text of the tokens generated after the prompt, with
        special tokens skipped.

    Raises:
        OptionalTrainingDependencyError: If torch, peft, or transformers
            cannot be imported.
    """
    deps = load_inference_dependencies()
    torch = deps["torch"]
    PeftModel = deps["PeftModel"]
    AutoModelForCausalLM = deps["AutoModelForCausalLM"]
    AutoTokenizer = deps["AutoTokenizer"]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Try the tokenizer stored with the adapter only when the adapter path
    # exists; otherwise go straight to the base model so a failing base-model
    # load is not attempted twice.
    tokenizer = None
    if adapter.exists():
        try:
            tokenizer = AutoTokenizer.from_pretrained(adapter)
        except Exception:
            # Deliberate best-effort: any failure to load a tokenizer from the
            # adapter path falls back to the base model's tokenizer below.
            tokenizer = None
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(base_model)
    model = PeftModel.from_pretrained(model, adapter)
    model.to(device)
    model.eval()

    prompt = build_inference_prompt(user_input)
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Use the tokenizer's own pad id; it equals the EOS id whenever
            # the fallback assignment above was taken.
            pad_token_id=tokenizer.pad_token_id,
        )
    # Drop the prompt tokens so only the newly generated tail is decoded.
    input_length = encoded["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0][input_length:], skip_special_tokens=True)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run inference with a RouterCore LoRA adapter.")
    parser.add_argument("--base-model", required=True)
    parser.add_argument("--adapter", type=Path, required=True)
    parser.add_argument("--input", required=True)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    return parser.parse_args()


def main() -> None:
    """Run LoRA inference from the command line and print the results.

    Prints the raw model output, then the result of
    ``extract_first_json_object`` on that output as pretty-printed JSON, or
    a failure notice when that helper returns ``None``. If the optional
    inference dependencies are missing, prints the error and a hint, and
    returns without running inference.
    """
    cli = parse_args()
    try:
        generated = run_lora_inference(
            base_model=cli.base_model,
            adapter=cli.adapter,
            user_input=cli.input,
            max_new_tokens=cli.max_new_tokens,
        )
    except OptionalTrainingDependencyError as missing:
        print(str(missing))
        print("Skipping LoRA inference. Run `pip install transformers peft torch` to enable it.")
    else:
        print("Raw model output:")
        print(generated)
        payload = extract_first_json_object(generated)
        print("\nParsed JSON:")
        report = (
            "Parse failed: no valid JSON object found."
            if payload is None
            else json.dumps(payload, indent=2, sort_keys=True)
        )
        print(report)


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()