from __future__ import annotations import argparse import json from dataclasses import dataclass from pathlib import Path from typing import Any, Callable from district_llm.prompting import format_district_prompt from district_llm.repair import RepairConfig, RepairReport, sanitize_action_payload from district_llm.schema import DistrictAction, DistrictStateSummary from district_llm.summary_builder import DistrictStateSummaryBuilder from env.observation_builder import ObservationConfig from env.reward import RewardConfig from env.traffic_env import EnvConfig, TrafficEnv from training.cityflow_dataset import CityFlowDataset def _extract_json_object(payload: str) -> str: start = payload.find("{") end = payload.rfind("}") if start == -1 or end == -1 or end <= start: raise ValueError("No JSON object found in model output.") return payload[start : end + 1] @dataclass(frozen=True) class DistrictLLMInferenceResult: action: DistrictAction raw_text: str parsed_payload_before_repair: dict[str, Any] repair_report: RepairReport json_valid: bool schema_valid_before_repair: bool class DistrictLLMInference: def __init__( self, generator_fn: Callable[[str], str] | None = None, model_name_or_path: str | None = None, device: str | None = None, fallback_action: DistrictAction | None = None, repair_config: RepairConfig | None = None, ): self.fallback_action = fallback_action or DistrictAction.default_hold() self.generator_fn = generator_fn self.repair_config = repair_config or RepairConfig() self.tokenizer = None self.model = None self.device = device or "cpu" if self.generator_fn is None: if not model_name_or_path: raise ValueError("Provide either generator_fn or model_name_or_path.") import torch from transformers import AutoModelForCausalLM, AutoTokenizer self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") model_dir = Path(model_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None: self.tokenizer.pad_token = self.tokenizer.eos_token if (model_dir / "adapter_config.json").exists(): try: from peft import AutoPeftModelForCausalLM except ImportError as exc: raise ImportError("Loading a LoRA adapter requires the 'peft' package.") from exc self.model = AutoPeftModelForCausalLM.from_pretrained(model_name_or_path).to(self.device) else: self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device) self.model.eval() def generate_raw(self, prompt: str, max_new_tokens: int = 128) -> str: if self.generator_fn is not None: return self.generator_fn(prompt) import torch assert self.model is not None and self.tokenizer is not None inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, ) generated = outputs[0][inputs["input_ids"].shape[1] :] return self.tokenizer.decode(generated, skip_special_tokens=True) def parse_action( self, payload: str, summary: DistrictStateSummary | None = None, ) -> tuple[DistrictAction, RepairReport, dict[str, Any], bool, bool]: json_valid = True schema_valid_before_repair = True try: parsed_payload = json.loads(_extract_json_object(payload)) except Exception: json_valid = False schema_valid_before_repair = False parsed_payload = self.fallback_action.to_dict() action, repair_report = sanitize_action_payload( payload=parsed_payload, summary=summary, config=self.repair_config, ) return action, repair_report, parsed_payload, json_valid, schema_valid_before_repair def predict_with_result( self, summary: DistrictStateSummary, max_new_tokens: int = 128, ) -> DistrictLLMInferenceResult: prompt = format_district_prompt( summary, max_target_intersections=self.repair_config.max_target_intersections, allow_only_visible_candidates=self.repair_config.allow_only_visible_candidates, ) raw = self.generate_raw(prompt=prompt, max_new_tokens=max_new_tokens) action, repair_report, parsed_payload, json_valid, schema_valid_before_repair = self.parse_action( raw, summary=summary, ) return DistrictLLMInferenceResult( action=action, raw_text=raw, parsed_payload_before_repair=parsed_payload, repair_report=repair_report, json_valid=json_valid, schema_valid_before_repair=schema_valid_before_repair, ) def predict(self, summary: DistrictStateSummary, max_new_tokens: int = 128) -> DistrictAction: return self.predict_with_result(summary=summary, max_new_tokens=max_new_tokens).action def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run single-sample district LLM inference.") parser.add_argument("--model", required=True, help="Model name, local path, or LoRA adapter path.") parser.add_argument("--generated-root", default="data/generated") parser.add_argument("--splits-root", default="data/splits") parser.add_argument("--city-id", required=True) parser.add_argument("--scenario-name", required=True) parser.add_argument("--district-id", required=True) parser.add_argument("--device", default=None) parser.add_argument("--max-new-tokens", type=int, default=128) parser.add_argument( "--allow-only-visible-candidates", action=argparse.BooleanOptionalAction, default=True, ) parser.add_argument("--max-target-intersections", type=int, default=3) parser.add_argument( "--fallback-on-empty-targets", action=argparse.BooleanOptionalAction, default=True, ) parser.add_argument( "--fallback-mode", choices=("heuristic", "hold", "none"), default="heuristic", ) return parser.parse_args() def build_env(scenario_spec) -> TrafficEnv: env_config = EnvConfig( simulator_interval=1, decision_interval=5, min_green_time=10, thread_num=1, observation=ObservationConfig(), reward=RewardConfig(variant="wait_queue_throughput"), ) return TrafficEnv( city_id=scenario_spec.city_id, scenario_name=scenario_spec.scenario_name, city_dir=scenario_spec.city_dir, scenario_dir=scenario_spec.scenario_dir, config_path=scenario_spec.config_path, roadnet_path=scenario_spec.roadnet_path, district_map_path=scenario_spec.district_map_path, metadata_path=scenario_spec.metadata_path, env_config=env_config, ) def main() -> None: args = parse_args() dataset = CityFlowDataset( generated_root=args.generated_root, splits_root=args.splits_root, ) scenario_spec = dataset.build_scenario_spec(args.city_id, args.scenario_name) env = build_env(scenario_spec) summary_builder = DistrictStateSummaryBuilder(candidate_limit=max(6, args.max_target_intersections)) observation_batch = env.reset() summaries = summary_builder.build_all(env, observation_batch) if args.district_id not in summaries: raise ValueError(f"Unknown district_id '{args.district_id}' for {args.city_id}/{args.scenario_name}.") inference = DistrictLLMInference( model_name_or_path=args.model, device=args.device, fallback_action=DistrictAction.default_hold(), repair_config=RepairConfig( allow_only_visible_candidates=args.allow_only_visible_candidates, max_target_intersections=args.max_target_intersections, fallback_on_empty_targets=args.fallback_on_empty_targets, fallback_mode=args.fallback_mode, ), ) action = inference.predict( summary=summaries[args.district_id], max_new_tokens=args.max_new_tokens, ) print(action.to_pretty_json()) if __name__ == "__main__": main()