# Command-line entry point: run inference with a RouterCore LoRA adapter and
# print both the raw completion and the first JSON object found in it.
from __future__ import annotations

import argparse
import json
from pathlib import Path

from routercore.model_router import extract_first_json_object
from training.format_dataset import build_inference_prompt
from training.train_lora import OptionalTrainingDependencyError


def load_inference_dependencies():
    """Import the optional inference libraries and return them by name.

    The imports happen inside the function so that merely importing this
    module does not require torch, peft, or transformers.

    Returns:
        A dict with the keys ``"torch"``, ``"PeftModel"``,
        ``"AutoModelForCausalLM"`` and ``"AutoTokenizer"`` mapped to the
        imported module / classes.

    Raises:
        OptionalTrainingDependencyError: If any of the three packages
            cannot be imported; the original ``ImportError`` is chained.
    """
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as import_error:
        message = "Optional inference dependencies are unavailable. Install transformers, peft, and torch."
        raise OptionalTrainingDependencyError(message) from import_error

    return dict(
        torch=torch,
        PeftModel=PeftModel,
        AutoModelForCausalLM=AutoModelForCausalLM,
        AutoTokenizer=AutoTokenizer,
    )


def run_lora_inference(
    *,
    base_model: str,
    adapter: Path,
    user_input: str,
    max_new_tokens: int,
) -> str:
    """Generate a completion for ``user_input`` with a LoRA-adapted model.

    Args:
        base_model: Value handed to ``from_pretrained`` to load the base
            causal LM (and the tokenizer, when the adapter cannot supply one).
        adapter: Path handed to ``PeftModel.from_pretrained``; also tried
            first as the tokenizer source when it exists on disk.
        user_input: Text passed to ``build_inference_prompt`` to build the
            prompt.
        max_new_tokens: Forwarded to ``model.generate`` unchanged.

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are sliced off), with special tokens skipped.

    Raises:
        OptionalTrainingDependencyError: If the optional inference
            libraries are not installed.
    """
    libs = load_inference_dependencies()
    torch = libs["torch"]
    peft_model_cls = libs["PeftModel"]
    causal_lm_cls = libs["AutoModelForCausalLM"]
    tokenizer_cls = libs["AutoTokenizer"]

    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    # Try the adapter directory first when it exists on disk; any failure
    # there falls back to the base model's tokenizer.
    try:
        if adapter.exists():
            tokenizer_source = adapter
        else:
            tokenizer_source = base_model
        tokenizer = tokenizer_cls.from_pretrained(tokenizer_source)
    except Exception:
        tokenizer = tokenizer_cls.from_pretrained(base_model)

    # Reuse EOS as the pad token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    backbone = causal_lm_cls.from_pretrained(base_model)
    model = peft_model_cls.from_pretrained(backbone, adapter)
    model.to(target_device)
    model.eval()

    prompt_text = build_inference_prompt(user_input)
    model_inputs = tokenizer(prompt_text, return_tensors="pt").to(target_device)

    # do_sample=False selects non-sampling decoding.
    generation_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "pad_token_id": tokenizer.eos_token_id,
    }
    with torch.no_grad():
        generated = model.generate(**model_inputs, **generation_options)

    # The generated sequence starts with the prompt tokens; drop them so only
    # the continuation is decoded.
    prompt_token_count = model_inputs["input_ids"].shape[-1]
    continuation_ids = generated[0][prompt_token_count:]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the inference script.

    Returns:
        A namespace with ``base_model``, ``adapter`` (a ``Path``), ``input``
        and ``max_new_tokens`` (an ``int``, default 512).
    """
    cli = argparse.ArgumentParser(description="Run inference with a RouterCore LoRA adapter.")

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--base-model", {"required": True}),
        ("--adapter", {"type": Path, "required": True}),
        ("--input", {"required": True}),
        ("--max-new-tokens", {"type": int, "default": 512}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def main() -> None:
    """Run inference from the command line and print the results.

    Prints the raw model output, then the first JSON object that
    ``extract_first_json_object`` finds in it (pretty-printed with sorted
    keys), or a failure notice when that helper returns ``None``. When the
    optional inference libraries are missing, prints the error and an
    install hint and returns without raising.
    """
    options = parse_args()

    try:
        completion = run_lora_inference(
            base_model=options.base_model,
            adapter=options.adapter,
            user_input=options.input,
            max_new_tokens=options.max_new_tokens,
        )
    except OptionalTrainingDependencyError as missing_deps:
        # Missing optional packages are reported, not treated as a crash.
        print(missing_deps)
        print("Skipping LoRA inference. Run `pip install transformers peft torch` to enable it.")
    else:
        print("Raw model output:")
        print(completion)

        payload = extract_first_json_object(completion)
        print("\nParsed JSON:")
        if payload is not None:
            print(json.dumps(payload, indent=2, sort_keys=True))
        else:
            print("Parse failed: no valid JSON object found.")


if __name__ == "__main__":
    main()