# File: routercore/training/inference_lora.py
# Last commit: "Deploy RouterCore Gradio demo" (1137e50, verified) by Jayteare
from __future__ import annotations
import argparse
import json
from pathlib import Path
from routercore.model_router import extract_first_json_object
from training.format_dataset import build_inference_prompt
from training.train_lora import OptionalTrainingDependencyError
def load_inference_dependencies():
    """Import the optional inference stack on demand.

    The heavy third-party packages are imported inside this function so that
    the module itself can be imported when they are not installed.

    Returns:
        A dict with the keys ``"torch"``, ``"PeftModel"``,
        ``"AutoModelForCausalLM"`` and ``"AutoTokenizer"`` mapped to the
        imported module / classes.

    Raises:
        OptionalTrainingDependencyError: If any of the imports fails with
            ``ImportError``; the original error is chained as the cause.
    """
    try:
        import torch as torch_module
        from peft import PeftModel as peft_model_cls
        from transformers import AutoModelForCausalLM as causal_lm_cls
        from transformers import AutoTokenizer as tokenizer_cls
    except ImportError as import_error:
        message = (
            "Optional inference dependencies are unavailable. Install transformers, peft, and torch."
        )
        raise OptionalTrainingDependencyError(message) from import_error
    else:
        return dict(
            torch=torch_module,
            PeftModel=peft_model_cls,
            AutoModelForCausalLM=causal_lm_cls,
            AutoTokenizer=tokenizer_cls,
        )
def run_lora_inference(
    *,
    base_model: str,
    adapter: Path,
    user_input: str,
    max_new_tokens: int,
) -> str:
    """Generate a completion for *user_input* with a LoRA adapter applied.

    Args:
        base_model: Identifier handed to ``AutoModelForCausalLM.from_pretrained``
            (and to the tokenizer loader as a fallback).
        adapter: Location of the LoRA adapter passed to
            ``PeftModel.from_pretrained``; also tried first as the tokenizer
            source when it exists on disk.
        user_input: Text turned into a prompt by ``build_inference_prompt``.
        max_new_tokens: Forwarded to ``model.generate``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are sliced off), with special tokens skipped.

    Raises:
        OptionalTrainingDependencyError: Propagated from
            ``load_inference_dependencies`` when the optional packages are
            missing.
    """
    libs = load_inference_dependencies()
    torch_mod, peft_cls, causal_lm_cls, tokenizer_cls = (
        libs[key]
        for key in ("torch", "PeftModel", "AutoModelForCausalLM", "AutoTokenizer")
    )

    if torch_mod.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    # Prefer the tokenizer stored next to the adapter; on any failure fall
    # back to the base model's tokenizer (best-effort by design).
    try:
        if adapter.exists():
            tok = tokenizer_cls.from_pretrained(adapter)
        else:
            tok = tokenizer_cls.from_pretrained(base_model)
    except Exception:
        tok = tokenizer_cls.from_pretrained(base_model)

    # Reuse EOS as the pad token when the tokenizer does not define one.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    backbone = causal_lm_cls.from_pretrained(base_model)
    lora_model = peft_cls.from_pretrained(backbone, adapter)
    lora_model.to(target_device)
    lora_model.eval()

    prompt_text = build_inference_prompt(user_input)
    batch = tok(prompt_text, return_tensors="pt").to(target_device)

    generation_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "pad_token_id": tok.eos_token_id,
    }
    with torch_mod.no_grad():
        generated = lora_model.generate(**batch, **generation_options)

    # The output sequence starts with the prompt; keep only what follows it.
    prompt_token_count = batch["input_ids"].shape[-1]
    completion_ids = generated[0][prompt_token_count:]
    return tok.decode(completion_ids, skip_special_tokens=True)
def _positive_int(value: str) -> int:
    """Convert an argparse string to an int and reject values below 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return number


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for LoRA inference.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the zero-argument call in
            ``main`` relies on; passing a list allows programmatic use.

    Returns:
        A namespace with ``base_model`` (str), ``adapter`` (Path), ``input``
        (str) and ``max_new_tokens`` (int, default 512).

    Raises:
        SystemExit: Raised by argparse when a required option is missing or a
            value is invalid (including a non-positive ``--max-new-tokens``).
    """
    parser = argparse.ArgumentParser(
        description="Run inference with a RouterCore LoRA adapter."
    )
    parser.add_argument(
        "--base-model",
        required=True,
        help="Base model name or path to load before applying the adapter.",
    )
    parser.add_argument(
        "--adapter",
        type=Path,
        required=True,
        help="Path to the LoRA adapter to apply on top of the base model.",
    )
    parser.add_argument(
        "--input",
        required=True,
        help="User input used to build the inference prompt.",
    )
    parser.add_argument(
        "--max-new-tokens",
        # Fail at the CLI boundary instead of after the model has been loaded.
        type=_positive_int,
        default=512,
        help="Maximum number of tokens to generate (default: %(default)s).",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Command-line entry point: run inference and print the result.

    Prints the raw model output, then the first JSON object found in it
    (pretty-printed with sorted keys) or a parse-failure message. If the
    optional inference dependencies are missing, prints the error plus an
    install hint and returns without running inference.
    """
    cli = parse_args()
    inference_kwargs = {
        "base_model": cli.base_model,
        "adapter": cli.adapter,
        "user_input": cli.input,
        "max_new_tokens": cli.max_new_tokens,
    }

    try:
        generated_text = run_lora_inference(**inference_kwargs)
    except OptionalTrainingDependencyError as missing_deps:
        print(missing_deps)
        print("Skipping LoRA inference. Run `pip install transformers peft torch` to enable it.")
        return

    print("Raw model output:", generated_text, sep="\n")

    json_payload = extract_first_json_object(generated_text)
    print("\nParsed JSON:")
    if json_payload is not None:
        print(json.dumps(json_payload, indent=2, sort_keys=True))
    else:
        print("Parse failed: no valid JSON object found.")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()