iris-at-text2sparql / src /config.py
Alex Latipov
Harden frozen eval prompts and judge JSON handling
d745844
"""Configuration loading for the Text2SPARQL repair pipeline."""
from __future__ import annotations
import yaml
from pathlib import Path
from pydantic import BaseModel, Field
class RuntimeConfig(BaseModel):
    """Runtime configuration for the pipeline.

    All fields have defaults so a partially-specified YAML file is valid;
    unknown keys are rejected by pydantic's default model behavior.
    """
    # LLM backend: "vllm" for local models, "openai" for API
    llm_backend: str = "vllm"
    # LLM model names (generation / committee judging / repair stages)
    generation_model: str = "QuantTrio/Qwen3.5-27B-AWQ"
    committee_model: str = "QuantTrio/Qwen3.5-27B-AWQ"
    repair_model: str = "QuantTrio/Qwen3.5-27B-AWQ"
    # LLM sampling temperatures per stage
    temperature_generation: float = 0.7
    temperature_committee: float = 0.2
    temperature_repair: float = 0.3
    # vLLM-specific settings
    gpu_memory_utilization: float = 0.85
    max_model_len: int = 8192
    enforce_eager: bool = True
    max_tokens: int = 4096
    # Pipeline parameters
    k_candidates: int = 1
    max_repair_iterations: int = 3
    # None means "use max_repair_iterations" (see the *_attempt_limit properties)
    max_syntax_attempts: int | None = None
    max_semantic_attempts: int | None = None
    # Context parameters (retrieval fan-out per category)
    entity_top_k: int = 5
    relation_top_k: int = 8
    class_top_k: int = 5
    linker_mode: str = "internal_min"
    judge_mode: str = "committee_multi"
    dbpedia_lookup_url: str = "https://lookup.dbpedia.org/api/search"
    dbpedia_spotlight_base_url: str = "https://api.dbpedia-spotlight.org"
    embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_candidate_pool_size: int = 16
    # Prompt-bank overrides (key -> relative or absolute file path)
    prompt_files: dict[str, str] = Field(default_factory=dict)
    # Validation parameters
    request_timeout_sec: int = 30
    # Selection scoring weights (positive = reward, negative = penalty)
    selection_weights: dict[str, float] = Field(default_factory=lambda: {
        "parse_ok": 5.0,
        "execute_ok": 5.0,
        "answer_type_fit": 2.0,
        "schema_fit": 2.0,
        "timeout": -2.0,
        "empty_result": -1.5,
        "huge_result": -1.0,
        "suspicious_flag": -0.5,
    })
    # Dataset configs (dataset_id -> dict of settings)
    datasets: dict[str, dict] = Field(default_factory=dict)
    @property
    def syntax_attempt_limit(self) -> int:
        """Effective syntax-repair attempt cap.

        Uses an explicit ``is not None`` check (not ``or``) so that a
        configured value of 0 is honored instead of silently falling
        back to ``max_repair_iterations`` — 0 is falsy but meaningful
        ("disable syntax repair").
        """
        if self.max_syntax_attempts is not None:
            return self.max_syntax_attempts
        return self.max_repair_iterations
    @property
    def semantic_attempt_limit(self) -> int:
        """Effective semantic-repair attempt cap; same None semantics
        as :meth:`syntax_attempt_limit` (0 is a valid explicit value)."""
        if self.max_semantic_attempts is not None:
            return self.max_semantic_attempts
        return self.max_repair_iterations
def _load_yaml(path: str | Path) -> dict:
    """Read *path* as YAML and return its contents as a dict.

    Raises FileNotFoundError with a clear message when the file is
    missing; an empty/blank YAML file yields an empty dict.
    """
    yaml_path = Path(path)
    if not yaml_path.exists():
        raise FileNotFoundError(f"Config file not found: {yaml_path}")
    with yaml_path.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    return loaded or {}
def _deep_merge(base: dict, override: dict) -> dict:
merged = dict(base)
for key, value in override.items():
if (
key in merged
and isinstance(merged[key], dict)
and isinstance(value, dict)
):
merged[key] = _deep_merge(merged[key], value)
else:
merged[key] = value
return merged
def load_config(path: str, base_path: str | None = None) -> RuntimeConfig:
    """Load a YAML config file and return a merged RuntimeConfig.

    When *base_path* is given (and non-empty), it is loaded first and
    the config at *path* is deep-merged on top of it, so *path* acts
    as an override layer.
    """
    data = _load_yaml(path)
    if base_path:
        base = _load_yaml(base_path)
        data = _deep_merge(base, data)
    return RuntimeConfig(**data)