# Upload artifact (commented out so the module stays importable):
# tokev's picture
# Add files using upload-large-folder tool
# 5893134 verified
from __future__ import annotations
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
from district_llm.prompting import format_district_prompt
from district_llm.repair import RepairConfig, RepairReport, sanitize_action_payload
from district_llm.schema import DistrictAction, DistrictStateSummary
from district_llm.summary_builder import DistrictStateSummaryBuilder
from env.observation_builder import ObservationConfig
from env.reward import RewardConfig
from env.traffic_env import EnvConfig, TrafficEnv
from training.cityflow_dataset import CityFlowDataset
def _extract_json_object(payload: str) -> str:
    """Return the substring of *payload* spanning the first ``{`` to the last ``}``.

    Raises:
        ValueError: if no plausible JSON-object delimiters are found.
    """
    opening = payload.find("{")
    closing = payload.rfind("}")
    if opening < 0 or closing < 0 or closing <= opening:
        raise ValueError("No JSON object found in model output.")
    return payload[opening : closing + 1]
@dataclass(frozen=True)
class DistrictLLMInferenceResult:
    """Outcome of one district-level LLM inference call.

    Bundles the final (possibly repaired) action with the raw model output
    and parse/repair diagnostics so callers can audit the pipeline.
    """

    # Final action after sanitize_action_payload (falls back to the
    # configured fallback action when no JSON could be parsed).
    action: DistrictAction
    # Verbatim text generated by the model.
    raw_text: str
    # Payload as parsed from raw_text, before any repair was applied;
    # equals the fallback action's dict when JSON parsing failed.
    parsed_payload_before_repair: dict[str, Any]
    # Record of what sanitize_action_payload changed.
    repair_report: RepairReport
    # False when no JSON object could be extracted/parsed from raw_text.
    json_valid: bool
    # Set False together with json_valid on a parse failure.
    schema_valid_before_repair: bool
class DistrictLLMInference:
    """Generate, parse, and repair district-level actions from an LLM.

    Either an injected ``generator_fn`` (prompt -> raw text) is used, or a
    HuggingFace causal LM — optionally a LoRA adapter directory — is loaded
    from ``model_name_or_path``.
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str] | None = None,
        model_name_or_path: str | None = None,
        device: str | None = None,
        fallback_action: DistrictAction | None = None,
        repair_config: RepairConfig | None = None,
    ):
        """Configure generation; loads tokenizer/model only when no generator_fn is given.

        Raises:
            ValueError: if neither generator_fn nor model_name_or_path is provided.
            ImportError: if a LoRA adapter is detected but 'peft' is not installed.
        """
        # Action returned when model output yields no parseable JSON.
        self.fallback_action = fallback_action or DistrictAction.default_hold()
        self.generator_fn = generator_fn
        self.repair_config = repair_config or RepairConfig()
        self.tokenizer = None
        self.model = None
        self.device = device or "cpu"
        if self.generator_fn is None:
            if not model_name_or_path:
                raise ValueError("Provide either generator_fn or model_name_or_path.")
            # Heavy imports kept local so that injecting generator_fn does
            # not require torch/transformers at module import time.
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
            model_dir = Path(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
                # Many causal LMs ship without a pad token; reuse EOS.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Presence of adapter_config.json marks a PEFT/LoRA checkpoint.
            if (model_dir / "adapter_config.json").exists():
                try:
                    from peft import AutoPeftModelForCausalLM
                except ImportError as exc:
                    raise ImportError("Loading a LoRA adapter requires the 'peft' package.") from exc
                self.model = AutoPeftModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            else:
                self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            self.model.eval()

    def generate_raw(self, prompt: str, max_new_tokens: int = 128) -> str:
        """Return the model's raw completion for *prompt* (greedy decoding)."""
        if self.generator_fn is not None:
            return self.generator_fn(prompt)
        import torch
        assert self.model is not None and self.tokenizer is not None
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Strip the prompt tokens; decode only newly generated ones.
        generated = outputs[0][inputs["input_ids"].shape[1] :]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def parse_action(
        self,
        payload: str,
        summary: DistrictStateSummary | None = None,
    ) -> tuple[DistrictAction, RepairReport, dict[str, Any], bool, bool]:
        """Parse raw model text into a sanitized action.

        Returns:
            (action, repair_report, parsed_payload, json_valid,
            schema_valid_before_repair). On a parse failure the fallback
            action's dict is sanitized instead and both validity flags are
            False.
        """
        json_valid = True
        schema_valid_before_repair = True
        try:
            parsed_payload = json.loads(_extract_json_object(payload))
        except Exception:
            # Broad catch is deliberate: any extraction/decoding failure
            # routes through the fallback action rather than crashing.
            json_valid = False
            schema_valid_before_repair = False
            parsed_payload = self.fallback_action.to_dict()
        action, repair_report = sanitize_action_payload(
            payload=parsed_payload,
            summary=summary,
            config=self.repair_config,
        )
        return action, repair_report, parsed_payload, json_valid, schema_valid_before_repair

    def predict_with_result(
        self,
        summary: DistrictStateSummary,
        max_new_tokens: int = 128,
    ) -> DistrictLLMInferenceResult:
        """Run the full prompt -> generate -> parse -> repair pipeline."""
        prompt = format_district_prompt(
            summary,
            max_target_intersections=self.repair_config.max_target_intersections,
            allow_only_visible_candidates=self.repair_config.allow_only_visible_candidates,
        )
        raw = self.generate_raw(prompt=prompt, max_new_tokens=max_new_tokens)
        action, repair_report, parsed_payload, json_valid, schema_valid_before_repair = self.parse_action(
            raw,
            summary=summary,
        )
        return DistrictLLMInferenceResult(
            action=action,
            raw_text=raw,
            parsed_payload_before_repair=parsed_payload,
            repair_report=repair_report,
            json_valid=json_valid,
            schema_valid_before_repair=schema_valid_before_repair,
        )

    def predict(self, summary: DistrictStateSummary, max_new_tokens: int = 128) -> DistrictAction:
        """Convenience wrapper returning only the final action."""
        return self.predict_with_result(summary=summary, max_new_tokens=max_new_tokens).action
def parse_args() -> argparse.Namespace:
    """Parse command-line options for single-sample district LLM inference."""
    parser = argparse.ArgumentParser(description="Run single-sample district LLM inference.")
    add = parser.add_argument
    add("--model", required=True, help="Model name, local path, or LoRA adapter path.")
    add("--generated-root", default="data/generated")
    add("--splits-root", default="data/splits")
    add("--city-id", required=True)
    add("--scenario-name", required=True)
    add("--district-id", required=True)
    add("--device", default=None)
    add("--max-new-tokens", type=int, default=128)
    # Boolean flags accept both --flag and --no-flag forms.
    toggle = argparse.BooleanOptionalAction
    add("--allow-only-visible-candidates", action=toggle, default=True)
    add("--max-target-intersections", type=int, default=3)
    add("--fallback-on-empty-targets", action=toggle, default=True)
    add("--fallback-mode", choices=("heuristic", "hold", "none"), default="heuristic")
    return parser.parse_args()
def build_env(scenario_spec) -> TrafficEnv:
    """Construct a TrafficEnv for *scenario_spec* with fixed inference-time settings."""
    # Fixed simulation and reward configuration used for inference runs.
    config = EnvConfig(
        simulator_interval=1,
        decision_interval=5,
        min_green_time=10,
        thread_num=1,
        observation=ObservationConfig(),
        reward=RewardConfig(variant="wait_queue_throughput"),
    )
    # All filesystem locations come straight off the scenario spec.
    paths = dict(
        city_dir=scenario_spec.city_dir,
        scenario_dir=scenario_spec.scenario_dir,
        config_path=scenario_spec.config_path,
        roadnet_path=scenario_spec.roadnet_path,
        district_map_path=scenario_spec.district_map_path,
        metadata_path=scenario_spec.metadata_path,
    )
    return TrafficEnv(
        city_id=scenario_spec.city_id,
        scenario_name=scenario_spec.scenario_name,
        env_config=config,
        **paths,
    )
def main() -> None:
    """CLI entry point: run one district-level inference and print the action."""
    args = parse_args()
    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    spec = dataset.build_scenario_spec(args.city_id, args.scenario_name)
    env = build_env(spec)
    # Candidate limit floors at 6 — presumably what the prompt format
    # expects; verify against DistrictStateSummaryBuilder's callers.
    builder = DistrictStateSummaryBuilder(candidate_limit=max(6, args.max_target_intersections))
    first_batch = env.reset()
    summaries = builder.build_all(env, first_batch)
    if args.district_id not in summaries:
        raise ValueError(f"Unknown district_id '{args.district_id}' for {args.city_id}/{args.scenario_name}.")
    repair = RepairConfig(
        allow_only_visible_candidates=args.allow_only_visible_candidates,
        max_target_intersections=args.max_target_intersections,
        fallback_on_empty_targets=args.fallback_on_empty_targets,
        fallback_mode=args.fallback_mode,
    )
    inference = DistrictLLMInference(
        model_name_or_path=args.model,
        device=args.device,
        fallback_action=DistrictAction.default_hold(),
        repair_config=repair,
    )
    action = inference.predict(
        summary=summaries[args.district_id],
        max_new_tokens=args.max_new_tokens,
    )
    print(action.to_pretty_json())


if __name__ == "__main__":
    main()