Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Callable | |
| from district_llm.prompting import format_district_prompt | |
| from district_llm.repair import RepairConfig, RepairReport, sanitize_action_payload | |
| from district_llm.schema import DistrictAction, DistrictStateSummary | |
| from district_llm.summary_builder import DistrictStateSummaryBuilder | |
| from env.observation_builder import ObservationConfig | |
| from env.reward import RewardConfig | |
| from env.traffic_env import EnvConfig, TrafficEnv | |
| from training.cityflow_dataset import CityFlowDataset | |
| def _extract_json_object(payload: str) -> str: | |
| start = payload.find("{") | |
| end = payload.rfind("}") | |
| if start == -1 or end == -1 or end <= start: | |
| raise ValueError("No JSON object found in model output.") | |
| return payload[start : end + 1] | |
@dataclass
class DistrictLLMInferenceResult:
    """Full outcome of one district inference call, for diagnostics and logging.

    Instances are built with keyword arguments by
    ``DistrictLLMInference.predict_with_result``. That requires the generated
    ``__init__``, so the class must be a dataclass.
    """

    # Final action returned by sanitize_action_payload.
    action: DistrictAction
    # Unmodified text returned by the generator.
    raw_text: str
    # Decoded JSON payload fed to the sanitizer. When decoding failed, the caller
    # stores the fallback action's dict here instead.
    parsed_payload_before_repair: dict[str, Any]
    # Report produced by sanitize_action_payload.
    repair_report: RepairReport
    # True when a JSON object was extracted and decoded from raw_text.
    json_valid: bool
    # Flag reported by DistrictLLMInference.parse_action alongside json_valid.
    schema_valid_before_repair: bool
class DistrictLLMInference:
    """Produce a sanitized ``DistrictAction`` for a district summary using an LLM.

    Generation goes through either a caller-supplied ``generator_fn`` (prompt in,
    raw text out) or a Hugging Face causal LM loaded from ``model_name_or_path``.
    The raw text is parsed as JSON and always passed through
    ``sanitize_action_payload`` with ``repair_config`` before being returned.
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str] | None = None,
        model_name_or_path: str | None = None,
        device: str | None = None,
        fallback_action: DistrictAction | None = None,
        repair_config: RepairConfig | None = None,
    ):
        """Set up the generation backend.

        Args:
            generator_fn: Optional prompt -> text callable. When provided, no model
                is loaded and ``model_name_or_path`` is ignored.
            model_name_or_path: Model id or local path passed to
                ``from_pretrained``. A local directory containing
                ``adapter_config.json`` is loaded as a PEFT/LoRA adapter.
            device: Torch device string. If omitted, it is ``"cpu"`` when
                ``generator_fn`` is used. Otherwise it is ``"cuda"`` when
                available, else ``"cpu"``.
            fallback_action: Action whose dict is used as the payload when the
                model output cannot be parsed. Defaults to
                ``DistrictAction.default_hold()``.
            repair_config: Prompting/repair settings. Defaults to ``RepairConfig()``.

        Raises:
            ValueError: If neither ``generator_fn`` nor ``model_name_or_path`` is given.
            ImportError: If an adapter directory is given but ``peft`` is not installed.
        """
        self.fallback_action = fallback_action or DistrictAction.default_hold()
        self.generator_fn = generator_fn
        self.repair_config = repair_config or RepairConfig()
        self.tokenizer = None
        self.model = None
        self.device = device or "cpu"

        if self.generator_fn is None:
            if not model_name_or_path:
                raise ValueError("Provide either generator_fn or model_name_or_path.")
            # Imported only on this path, so callers that supply generator_fn do
            # not need torch/transformers installed.
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
            model_dir = Path(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            # Ensure a pad token exists; reuse EOS when the tokenizer defines none.
            if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if (model_dir / "adapter_config.json").exists():
                try:
                    from peft import AutoPeftModelForCausalLM
                except ImportError as exc:
                    raise ImportError("Loading a LoRA adapter requires the 'peft' package.") from exc
                self.model = AutoPeftModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            else:
                self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            self.model.eval()

    def generate_raw(self, prompt: str, max_new_tokens: int = 128) -> str:
        """Return the raw completion for *prompt*, without the prompt tokens.

        ``max_new_tokens`` only applies to the loaded-model path; a custom
        ``generator_fn`` receives just the prompt. Model decoding is greedy
        (``do_sample=False``).

        Raises:
            RuntimeError: If no ``generator_fn`` was given and no model/tokenizer
                is loaded.
        """
        if self.generator_fn is not None:
            return self.generator_fn(prompt)
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("No generator_fn configured and no model/tokenizer loaded.")
        import torch

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Drop the echoed prompt so only newly generated tokens are decoded.
        generated = outputs[0][inputs["input_ids"].shape[1] :]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def parse_action(
        self,
        payload: str,
        summary: DistrictStateSummary | None = None,
    ) -> tuple[DistrictAction, RepairReport, dict[str, Any], bool, bool]:
        """Decode *payload* into a sanitized action.

        Args:
            payload: Raw model text, possibly with text around the JSON object.
            summary: Optional district summary forwarded to the sanitizer.

        Returns:
            ``(action, repair_report, parsed_payload, json_valid,
            schema_valid_before_repair)``. When no JSON object can be decoded,
            ``parsed_payload`` is the fallback action's dict and both flags are False.
        """
        json_valid = True
        # NOTE(review): this stays True whenever the JSON decodes. No schema check
        # happens here. Confirm that is intended, or derive it from repair_report.
        schema_valid_before_repair = True
        try:
            parsed_payload = json.loads(_extract_json_object(payload))
        except Exception:
            # Best-effort parsing: any extraction/decoding failure falls back to
            # the configured fallback action instead of aborting inference.
            json_valid = False
            schema_valid_before_repair = False
            parsed_payload = self.fallback_action.to_dict()
        action, repair_report = sanitize_action_payload(
            payload=parsed_payload,
            summary=summary,
            config=self.repair_config,
        )
        return action, repair_report, parsed_payload, json_valid, schema_valid_before_repair

    def predict_with_result(
        self,
        summary: DistrictStateSummary,
        max_new_tokens: int = 128,
    ) -> DistrictLLMInferenceResult:
        """Prompt the model for *summary* and return the action plus diagnostics.

        The prompt is built with the same target limit and visibility constraint
        used by the repair step.
        """
        prompt = format_district_prompt(
            summary,
            max_target_intersections=self.repair_config.max_target_intersections,
            allow_only_visible_candidates=self.repair_config.allow_only_visible_candidates,
        )
        raw = self.generate_raw(prompt=prompt, max_new_tokens=max_new_tokens)
        action, repair_report, parsed_payload, json_valid, schema_valid_before_repair = self.parse_action(
            raw,
            summary=summary,
        )
        return DistrictLLMInferenceResult(
            action=action,
            raw_text=raw,
            parsed_payload_before_repair=parsed_payload,
            repair_report=repair_report,
            json_valid=json_valid,
            schema_valid_before_repair=schema_valid_before_repair,
        )

    def predict(self, summary: DistrictStateSummary, max_new_tokens: int = 128) -> DistrictAction:
        """Return only the sanitized action for *summary*."""
        return self.predict_with_result(summary=summary, max_new_tokens=max_new_tokens).action
def parse_args() -> argparse.Namespace:
    """Parse the command line for a single district inference run.

    Returns:
        Namespace with model/dataset locations, the target district, generation
        options and repair options.
    """
    cli = argparse.ArgumentParser(description="Run single-sample district LLM inference.")
    cli.add_argument("--model", required=True, help="Model name, local path, or LoRA adapter path.")

    # Dataset locations.
    cli.add_argument("--generated-root", default="data/generated")
    cli.add_argument("--splits-root", default="data/splits")

    # Which district of which scenario to evaluate.
    for required_flag in ("--city-id", "--scenario-name", "--district-id"):
        cli.add_argument(required_flag, required=True)

    # Generation options.
    cli.add_argument("--device", default=None)
    cli.add_argument("--max-new-tokens", type=int, default=128)

    # Repair options; boolean switches accept --flag / --no-flag and default to on.
    on_by_default = {"action": argparse.BooleanOptionalAction, "default": True}
    cli.add_argument("--allow-only-visible-candidates", **on_by_default)
    cli.add_argument("--max-target-intersections", type=int, default=3)
    cli.add_argument("--fallback-on-empty-targets", **on_by_default)
    cli.add_argument("--fallback-mode", choices=("heuristic", "hold", "none"), default="heuristic")

    return cli.parse_args()
def build_env(scenario_spec) -> TrafficEnv:
    """Create a ``TrafficEnv`` for *scenario_spec* with fixed inference settings.

    Simulation timing is fixed at a simulator interval of 1, a decision interval
    of 5, a minimum green time of 10 and a single thread. The observation config
    uses its defaults and the reward variant is ``"wait_queue_throughput"``. The
    identifying and path fields are copied attribute-for-attribute from
    *scenario_spec*.
    """
    observation_cfg = ObservationConfig()
    reward_cfg = RewardConfig(variant="wait_queue_throughput")
    env_config = EnvConfig(
        simulator_interval=1,
        decision_interval=5,
        min_green_time=10,
        thread_num=1,
        observation=observation_cfg,
        reward=reward_cfg,
    )

    spec_fields = (
        "city_id",
        "scenario_name",
        "city_dir",
        "scenario_dir",
        "config_path",
        "roadnet_path",
        "district_map_path",
        "metadata_path",
    )
    spec_kwargs = {name: getattr(scenario_spec, name) for name in spec_fields}
    return TrafficEnv(**spec_kwargs, env_config=env_config)
def main() -> None:
    """CLI entry point: print the predicted action for one district.

    The scenario's environment is reset, summaries are built for every district,
    and the requested district's summary is sent to the model. The sanitized
    action is then printed as pretty JSON.

    Raises:
        ValueError: If ``--district-id`` is not among the built summaries.
    """
    args = parse_args()

    # Stage 1: build the scenario environment and summarize its initial state.
    dataset = CityFlowDataset(generated_root=args.generated_root, splits_root=args.splits_root)
    env = build_env(dataset.build_scenario_spec(args.city_id, args.scenario_name))
    builder = DistrictStateSummaryBuilder(candidate_limit=max(6, args.max_target_intersections))
    initial_observation = env.reset()
    summaries = builder.build_all(env, initial_observation)

    # Stage 2: validate the district before paying for model loading.
    if args.district_id not in summaries:
        raise ValueError(f"Unknown district_id '{args.district_id}' for {args.city_id}/{args.scenario_name}.")

    # Stage 3: load the model with CLI-driven repair settings and predict.
    hold_action = DistrictAction.default_hold()
    repair_settings = RepairConfig(
        allow_only_visible_candidates=args.allow_only_visible_candidates,
        max_target_intersections=args.max_target_intersections,
        fallback_on_empty_targets=args.fallback_on_empty_targets,
        fallback_mode=args.fallback_mode,
    )
    inference = DistrictLLMInference(
        model_name_or_path=args.model,
        device=args.device,
        fallback_action=hold_action,
        repair_config=repair_settings,
    )
    chosen = inference.predict(summary=summaries[args.district_id], max_new_tokens=args.max_new_tokens)
    print(chosen.to_pretty_json())
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()