Spaces:
Sleeping
Sleeping
soupstick commited on
Commit ·
daa8b58
1
Parent(s): b5c15e3
feat: Implement new features and refactor core architecture
Browse files- configs/pairwise_judge.yaml.example +12 -0
- connectors/weaviate.py +27 -0
- requirements.txt +1 -3
- searchqual/cli.py +106 -3
- searchqual/connectors/basic.py +36 -0
- searchqual/core/config.py +22 -1
- searchqual/core/connectors.py +12 -0
- searchqual/core/run.yaml.j2 +40 -0
- searchqual/core/runner.py +4 -40
- searchqual/judgers/pairwise_llm_judge.py +99 -0
- src/streamlit_app.py +0 -40
- streamlit_app.py +20 -2
configs/pairwise_judge.yaml.example
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
models:
|
| 2 |
+
- provider: openai
|
| 3 |
+
model: gpt-4-turbo
|
| 4 |
+
params:
|
| 5 |
+
temperature: 0.1
|
| 6 |
+
|
| 7 |
+
rubric: |
|
| 8 |
+
You are a search quality evaluator. You will be given a query and two answers, A and B.
|
| 9 |
+
Your task is to determine which answer is better. A better answer is more relevant, accurate, and helpful.
|
| 10 |
+
Respond with a JSON object with two keys: "winner" and "reason".
|
| 11 |
+
The "winner" should be either "A" or "B".
|
| 12 |
+
The "reason" should be a brief explanation of why you chose that answer.
|
connectors/weaviate.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
import aiohttp
|
| 6 |
+
import weaviate
|
| 7 |
+
from weaviate.auth import AuthApiKey
|
| 8 |
+
|
| 9 |
+
from ..core.config import WeaviateConnectorConfig
|
| 10 |
+
from ..core.dataset import Query
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class WeaviateConnector:
|
| 14 |
+
def __init__(self, cfg: WeaviateConnectorConfig) -> None:
|
| 15 |
+
self.cfg = cfg
|
| 16 |
+
auth_config = AuthApiKey(api_key=cfg.api_key) if cfg.api_key else None
|
| 17 |
+
self.client = weaviate.Client(url=cfg.url, auth_client_secret=auth_config)
|
| 18 |
+
|
| 19 |
+
async def search(self, session: aiohttp.ClientSession, query: Query) -> Dict[str, Any]:
|
| 20 |
+
del session
|
| 21 |
+
response = self.client.query.get(self.cfg.class_name, ["text", "score"]).with_near_text({"concepts": [query.text]}).do()
|
| 22 |
+
results = response.get("data", {}).get("Get", {}).get(self.cfg.class_name, [])
|
| 23 |
+
documents = [
|
| 24 |
+
{"doc_id": item.get("_additional", {}).get("id"), "text": item.get("text"), "score": item.get("_additional", {}).get("distance")}
|
| 25 |
+
for item in results
|
| 26 |
+
]
|
| 27 |
+
return {"documents": documents}
|
requirements.txt
CHANGED
|
@@ -1,3 +1 @@
|
|
| 1 |
-
|
| 2 |
-
pandas
|
| 3 |
-
streamlit
|
|
|
|
| 1 |
+
weaviate-client
|
|
|
|
|
|
searchqual/cli.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import json
|
|
@@ -7,6 +7,9 @@ import re
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Dict, List, Optional
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
import typer
|
| 11 |
from rich.console import Console
|
| 12 |
from rich.table import Table
|
|
@@ -43,7 +46,21 @@ def run(
|
|
| 43 |
if not dataset_path.is_absolute():
|
| 44 |
dataset_path = (config.parent / dataset_path).resolve()
|
| 45 |
dataset_cfg = _load_dataset_config(dataset_path)
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
console.print('[bold]Running evaluation...[/bold]')
|
| 48 |
result = runner.run()
|
| 49 |
storage = RunStorage(Path(run_cfg.output_dir))
|
|
@@ -60,11 +77,59 @@ def run(
|
|
| 60 |
_render_metrics_table(result.metrics)
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
@app.command()
|
| 64 |
def compare(
|
| 65 |
baseline: Path = typer.Option(..., '--b', help='Baseline run JSON'),
|
| 66 |
candidate: Path = typer.Option(..., '--a', help='Candidate run JSON'),
|
| 67 |
-
|
| 68 |
) -> None:
|
| 69 |
"""Compare two runs and display metric deltas."""
|
| 70 |
summary = compare_runs(baseline, candidate)
|
|
@@ -83,6 +148,42 @@ def compare(
|
|
| 83 |
f"{info['relative_delta']*100:+.2f}%",
|
| 84 |
)
|
| 85 |
console.print(table)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
if gate:
|
| 87 |
ok, failures = _evaluate_gates(summary, gate)
|
| 88 |
if ok:
|
|
@@ -184,3 +285,5 @@ def _render_metrics_table(metrics: Dict[str, float]) -> None:
|
|
| 184 |
|
| 185 |
if __name__ == '__main__':
|
| 186 |
app()
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import json
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Dict, List, Optional
|
| 9 |
|
| 10 |
+
import questionary
|
| 11 |
+
from jinja2 import Environment, FileSystemLoader
|
| 12 |
+
|
| 13 |
import typer
|
| 14 |
from rich.console import Console
|
| 15 |
from rich.table import Table
|
|
|
|
| 46 |
if not dataset_path.is_absolute():
|
| 47 |
dataset_path = (config.parent / dataset_path).resolve()
|
| 48 |
dataset_cfg = _load_dataset_config(dataset_path)
|
| 49 |
+
|
| 50 |
+
system = run_cfg.system
|
| 51 |
+
if system.type == 'http':
|
| 52 |
+
from .connectors.basic import HTTPSearchClient
|
| 53 |
+
connector = HTTPSearchClient(system)
|
| 54 |
+
elif system.type == 'local':
|
| 55 |
+
from .connectors.basic import LocalSearchClient
|
| 56 |
+
connector = LocalSearchClient(system)
|
| 57 |
+
elif system.type == 'weaviate':
|
| 58 |
+
from .connectors.weaviate import WeaviateConnector
|
| 59 |
+
connector = WeaviateConnector(system)
|
| 60 |
+
else:
|
| 61 |
+
raise TypeError(f'Unsupported system type: {system.type}')
|
| 62 |
+
|
| 63 |
+
runner = EvaluationRunner(run_cfg, dataset_cfg, connector, dataset_path=dataset_path)
|
| 64 |
console.print('[bold]Running evaluation...[/bold]')
|
| 65 |
result = runner.run()
|
| 66 |
storage = RunStorage(Path(run_cfg.output_dir))
|
|
|
|
| 77 |
_render_metrics_table(result.metrics)
|
| 78 |
|
| 79 |
|
| 80 |
+
|
| 81 |
+
@app.command()
|
| 82 |
+
def init():
|
| 83 |
+
"""Create a new run.yaml configuration file interactively."""
|
| 84 |
+
console.print("[bold green]Welcome to the SearchQual config wizard![/bold green]")
|
| 85 |
+
|
| 86 |
+
answers = {}
|
| 87 |
+
|
| 88 |
+
def _ensure(value):
|
| 89 |
+
if value is None:
|
| 90 |
+
console.print('[yellow]Wizard cancelled. No file written.[/yellow]')
|
| 91 |
+
raise typer.Exit(code=1)
|
| 92 |
+
return value
|
| 93 |
+
|
| 94 |
+
answers['name'] = _ensure(questionary.text("Enter a name for this evaluation run:", default="my-eval-run").ask())
|
| 95 |
+
answers['dataset_path'] = _ensure(questionary.text("Enter the path to your dataset.yaml:", default="configs/dataset.yaml.example").ask())
|
| 96 |
+
answers['system_url'] = _ensure(questionary.text("Enter the URL of the search system to evaluate:", default="http://localhost:8000/search").ask())
|
| 97 |
+
|
| 98 |
+
answers['retrieval_metrics'] = questionary.checkbox(
|
| 99 |
+
"Select retrieval metrics to track:",
|
| 100 |
+
choices=["ndcg@10", "recall@100", "mrr", "map@100", "precision@10"],
|
| 101 |
+
default=["ndcg@10", "recall@100"],
|
| 102 |
+
).ask() or []
|
| 103 |
+
|
| 104 |
+
answers['qa_metrics'] = questionary.checkbox(
|
| 105 |
+
"Select RAG/QA metrics to track:",
|
| 106 |
+
choices=["faithfulness", "citation_coverage", "nli_factuality", "helpfulness_likert"],
|
| 107 |
+
default=["faithfulness", "citation_coverage"],
|
| 108 |
+
).ask() or []
|
| 109 |
+
|
| 110 |
+
answers['use_judge'] = questionary.confirm("Use an LLM to judge answer quality?", default=True).ask()
|
| 111 |
+
if answers['use_judge']:
|
| 112 |
+
answers['judge_provider'] = _ensure(questionary.select(
|
| 113 |
+
"Select the LLM provider for judging:",
|
| 114 |
+
choices=["openai", "bedrock", "ollama"],
|
| 115 |
+
default="openai",
|
| 116 |
+
).ask())
|
| 117 |
+
answers['judge_model'] = _ensure(questionary.text("Enter the model name:", default="gpt-4-turbo").ask())
|
| 118 |
+
|
| 119 |
+
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'core'), trim_blocks=True, lstrip_blocks=True)
|
| 120 |
+
template = env.get_template('run.yaml.j2')
|
| 121 |
+
output_content = template.render(**answers)
|
| 122 |
+
|
| 123 |
+
output_path = Path('run.yaml')
|
| 124 |
+
output_path.write_text(output_content, encoding='utf-8')
|
| 125 |
+
|
| 126 |
+
console.print('\n[bold green]✅ Success![/bold green] Your run.yaml file has been created.')
|
| 127 |
+
console.print("You can now run an evaluation with: [cyan]sq run -c run.yaml[/cyan]")
|
| 128 |
@app.command()
|
| 129 |
def compare(
|
| 130 |
baseline: Path = typer.Option(..., '--b', help='Baseline run JSON'),
|
| 131 |
candidate: Path = typer.Option(..., '--a', help='Candidate run JSON'),
|
| 132 |
+
judge_config: Optional[Path] = typer.Option(None, '--judge', help='Path to pairwise judge config YAML'),
|
| 133 |
) -> None:
|
| 134 |
"""Compare two runs and display metric deltas."""
|
| 135 |
summary = compare_runs(baseline, candidate)
|
|
|
|
| 148 |
f"{info['relative_delta']*100:+.2f}%",
|
| 149 |
)
|
| 150 |
console.print(table)
|
| 151 |
+
|
| 152 |
+
if judge_config:
|
| 153 |
+
from .judgers.pairwise_llm_judge import PairwiseLLMJudge
|
| 154 |
+
from .core.runner import RunResult
|
| 155 |
+
|
| 156 |
+
console.print('\n[bold]Running pairwise evaluation...[/bold]')
|
| 157 |
+
judge_cfg = ConfigLoader.load_judge_config(judge_config)
|
| 158 |
+
judge = PairwiseLLMJudge(judge_cfg)
|
| 159 |
+
|
| 160 |
+
baseline_run = RunResult(**json.loads(baseline.read_text()))
|
| 161 |
+
candidate_run = RunResult(**json.loads(candidate.read_text()))
|
| 162 |
+
|
| 163 |
+
baseline_results = {r['query_id']: r for r in baseline_run.results}
|
| 164 |
+
candidate_results = {r['query_id']: r for r in candidate_run.results}
|
| 165 |
+
|
| 166 |
+
wins = 0
|
| 167 |
+
losses = 0
|
| 168 |
+
ties = 0
|
| 169 |
+
|
| 170 |
+
for query_id, baseline_result in baseline_results.items():
|
| 171 |
+
if query_id in candidate_results:
|
| 172 |
+
candidate_result = candidate_results[query_id]
|
| 173 |
+
judgment = asyncio.run(judge.evaluate(baseline_result['query'], baseline_result, candidate_result))
|
| 174 |
+
if judgment:
|
| 175 |
+
if judgment.winner == 'A':
|
| 176 |
+
losses += 1
|
| 177 |
+
elif judgment.winner == 'B':
|
| 178 |
+
wins += 1
|
| 179 |
+
else:
|
| 180 |
+
ties += 1
|
| 181 |
+
|
| 182 |
+
console.print(f'\n[bold]Pairwise Comparison Results[/bold]')
|
| 183 |
+
console.print(f'Candidate Wins: {wins}')
|
| 184 |
+
console.print(f'Candidate Losses: {losses}')
|
| 185 |
+
console.print(f'Ties: {ties}')
|
| 186 |
+
|
| 187 |
if gate:
|
| 188 |
ok, failures = _evaluate_gates(summary, gate)
|
| 189 |
if ok:
|
|
|
|
| 285 |
|
| 286 |
if __name__ == '__main__':
|
| 287 |
app()
|
| 288 |
+
|
| 289 |
+
|
searchqual/connectors/basic.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import importlib
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
import aiohttp
|
| 8 |
+
|
| 9 |
+
from ..core.config import HTTPSystemConfig, LocalSystemConfig
|
| 10 |
+
from ..core.dataset import Query
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class HTTPSearchClient:
|
| 14 |
+
def __init__(self, cfg: HTTPSystemConfig) -> None:
|
| 15 |
+
self.cfg = cfg
|
| 16 |
+
|
| 17 |
+
async def search(self, session: aiohttp.ClientSession, query: Query) -> Dict[str, Any]:
|
| 18 |
+
payload = {**self.cfg.params, 'query': query.text}
|
| 19 |
+
async with session.post(self.cfg.url, json=payload, headers=self.cfg.headers) as resp:
|
| 20 |
+
resp.raise_for_status()
|
| 21 |
+
return await resp.json()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class LocalSearchClient:
|
| 25 |
+
def __init__(self, cfg: LocalSystemConfig) -> None:
|
| 26 |
+
module = importlib.import_module(cfg.module)
|
| 27 |
+
self.callable = getattr(module, cfg.object)
|
| 28 |
+
self.params = cfg.params
|
| 29 |
+
|
| 30 |
+
async def search(self, session: aiohttp.ClientSession, query: Query) -> Dict[str, Any]:
|
| 31 |
+
del session
|
| 32 |
+
loop = asyncio.get_running_loop()
|
| 33 |
+
return await loop.run_in_executor(None, self._call, query)
|
| 34 |
+
|
| 35 |
+
def _call(self, query: Query) -> Dict[str, Any]:
|
| 36 |
+
return self.callable(query.text, **self.params)
|
searchqual/core/config.py
CHANGED
|
@@ -46,7 +46,14 @@ class LocalSystemConfig(BaseModel):
|
|
| 46 |
params: Dict[str, Any] = Field(default_factory=dict)
|
| 47 |
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
class BudgetConfig(BaseModel):
|
|
@@ -76,6 +83,11 @@ class JudgeConfig(BaseModel):
|
|
| 76 |
majority_vote: bool = True
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
class RunConfig(BaseModel):
|
| 80 |
system: SystemConfig
|
| 81 |
dataset: str = Field(..., description="Path to dataset.yaml or dataset id")
|
|
@@ -125,6 +137,14 @@ class ConfigLoader:
|
|
| 125 |
except ValidationError as exc:
|
| 126 |
raise ConfigError(f'Invalid dataset config: {path}\n{exc}') from exc
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
__all__ = [
|
| 130 |
'BudgetConfig',
|
|
@@ -136,6 +156,7 @@ __all__ = [
|
|
| 136 |
'JudgeModelConfig',
|
| 137 |
'LocalSystemConfig',
|
| 138 |
'MetricSelection',
|
|
|
|
| 139 |
'RunConfig',
|
| 140 |
'SystemConfig',
|
| 141 |
]
|
|
|
|
| 46 |
params: Dict[str, Any] = Field(default_factory=dict)
|
| 47 |
|
| 48 |
|
| 49 |
+
class WeaviateConnectorConfig(BaseModel):
|
| 50 |
+
type: Literal['weaviate'] = 'weaviate'
|
| 51 |
+
url: str
|
| 52 |
+
class_name: str
|
| 53 |
+
api_key: Optional[str] = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
SystemConfig = HTTPSystemConfig | LocalSystemConfig | WeaviateConnectorConfig
|
| 57 |
|
| 58 |
|
| 59 |
class BudgetConfig(BaseModel):
|
|
|
|
| 83 |
majority_vote: bool = True
|
| 84 |
|
| 85 |
|
| 86 |
+
class PairwiseJudgeConfig(BaseModel):
|
| 87 |
+
models: List[JudgeModelConfig] = Field(default_factory=list)
|
| 88 |
+
rubric: Optional[str] = None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
class RunConfig(BaseModel):
|
| 92 |
system: SystemConfig
|
| 93 |
dataset: str = Field(..., description="Path to dataset.yaml or dataset id")
|
|
|
|
| 137 |
except ValidationError as exc:
|
| 138 |
raise ConfigError(f'Invalid dataset config: {path}\n{exc}') from exc
|
| 139 |
|
| 140 |
+
@staticmethod
|
| 141 |
+
def load_judge_config(path: Path) -> PairwiseJudgeConfig:
|
| 142 |
+
payload = ConfigLoader.read_yaml(path)
|
| 143 |
+
try:
|
| 144 |
+
return PairwiseJudgeConfig.model_validate(payload)
|
| 145 |
+
except ValidationError as exc:
|
| 146 |
+
raise ConfigError(f'Invalid judge config: {path}\n{exc}') from exc
|
| 147 |
+
|
| 148 |
|
| 149 |
__all__ = [
|
| 150 |
'BudgetConfig',
|
|
|
|
| 156 |
'JudgeModelConfig',
|
| 157 |
'LocalSystemConfig',
|
| 158 |
'MetricSelection',
|
| 159 |
+
'PairwiseJudgeConfig',
|
| 160 |
'RunConfig',
|
| 161 |
'SystemConfig',
|
| 162 |
]
|
searchqual/core/connectors.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Protocol
|
| 4 |
+
|
| 5 |
+
import aiohttp
|
| 6 |
+
|
| 7 |
+
from .dataset import Query
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SearchConnector(Protocol):
|
| 11 |
+
async def search(self, session: aiohttp.ClientSession, query: Query) -> Dict[str, Any]:
|
| 12 |
+
...
|
searchqual/core/run.yaml.j2
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration for a SearchQual evaluation run, generated by 'sq init'
|
| 2 |
+
name: {{ name }}
|
| 3 |
+
dataset: {{ dataset_path }}
|
| 4 |
+
output_dir: runs/
|
| 5 |
+
|
| 6 |
+
system:
|
| 7 |
+
type: http
|
| 8 |
+
url: {{ system_url }}
|
| 9 |
+
# Example for other systems:
|
| 10 |
+
# type: weaviate
|
| 11 |
+
# url: "http://localhost:8080"
|
| 12 |
+
# class_name: "MyCollection"
|
| 13 |
+
# text_key: "content"
|
| 14 |
+
|
| 15 |
+
metrics:
|
| 16 |
+
retrieval:
|
| 17 |
+
{%- for metric in retrieval_metrics %}
|
| 18 |
+
- {{ metric }}
|
| 19 |
+
{%- endfor %}
|
| 20 |
+
qa:
|
| 21 |
+
{%- for metric in qa_metrics %}
|
| 22 |
+
- {{ metric }}
|
| 23 |
+
{%- endfor %}
|
| 24 |
+
|
| 25 |
+
{% if use_judge %}
|
| 26 |
+
judge:
|
| 27 |
+
rubric: configs/rubrics/qa_v1.md
|
| 28 |
+
models:
|
| 29 |
+
- provider: {{ judge_provider }}
|
| 30 |
+
model: {{ judge_model }}
|
| 31 |
+
sample_rate: 1.0
|
| 32 |
+
params:
|
| 33 |
+
# For OpenAI, add: api_key: ${OPENAI_API_KEY}
|
| 34 |
+
# For Bedrock, add: region: us-east-1
|
| 35 |
+
temperature: 0.1
|
| 36 |
+
{% endif %}
|
| 37 |
+
|
| 38 |
+
budget:
|
| 39 |
+
max_tokens: 100000
|
| 40 |
+
max_dollars: 5.00
|
searchqual/core/runner.py
CHANGED
|
@@ -70,54 +70,18 @@ class RunResult:
|
|
| 70 |
path.write_text(json.dumps(self.to_json(), indent=2), encoding='utf-8')
|
| 71 |
|
| 72 |
|
| 73 |
-
|
| 74 |
-
async def search(self, session: aiohttp.ClientSession, query: Query) -> Dict[str, Any]:
|
| 75 |
-
...
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
class HTTPSearchClient:
|
| 79 |
-
def __init__(self, cfg: HTTPSystemConfig) -> None:
|
| 80 |
-
self.cfg = cfg
|
| 81 |
-
|
| 82 |
-
async def search(self, session: aiohttp.ClientSession, query: Query) -> Dict[str, Any]:
|
| 83 |
-
payload = {**self.cfg.params, 'query': query.text}
|
| 84 |
-
async with session.post(self.cfg.url, json=payload, headers=self.cfg.headers) as resp:
|
| 85 |
-
resp.raise_for_status()
|
| 86 |
-
return await resp.json()
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class LocalSearchClient:
|
| 90 |
-
def __init__(self, cfg: LocalSystemConfig) -> None:
|
| 91 |
-
module = importlib.import_module(cfg.module)
|
| 92 |
-
self.callable = getattr(module, cfg.object)
|
| 93 |
-
self.params = cfg.params
|
| 94 |
-
|
| 95 |
-
async def search(self, session: aiohttp.ClientSession, query: Query) -> Dict[str, Any]:
|
| 96 |
-
del session
|
| 97 |
-
loop = asyncio.get_running_loop()
|
| 98 |
-
return await loop.run_in_executor(None, self._call, query)
|
| 99 |
-
|
| 100 |
-
def _call(self, query: Query) -> Dict[str, Any]:
|
| 101 |
-
return self.callable(query.text, **self.params)
|
| 102 |
|
| 103 |
|
| 104 |
class EvaluationRunner:
|
| 105 |
-
def __init__(self, run_cfg: RunConfig, dataset_cfg: DatasetConfig, dataset_path: Optional[Path] = None) -> None:
|
| 106 |
self.run_cfg = run_cfg
|
| 107 |
self.dataset_cfg = dataset_cfg
|
|
|
|
| 108 |
self.dataset_path = dataset_path
|
| 109 |
|
| 110 |
-
def _build_client(self) -> SearchClient:
|
| 111 |
-
system = self.run_cfg.system
|
| 112 |
-
if isinstance(system, HTTPSystemConfig):
|
| 113 |
-
return HTTPSearchClient(system)
|
| 114 |
-
if isinstance(system, LocalSystemConfig):
|
| 115 |
-
return LocalSearchClient(system)
|
| 116 |
-
raise TypeError(f'Unsupported system config: {system}')
|
| 117 |
-
|
| 118 |
async def run_async(self) -> RunResult:
|
| 119 |
dataset = self._load_dataset()
|
| 120 |
-
client = self._build_client()
|
| 121 |
connector = aiohttp.TCPConnector(limit=self.run_cfg.system.concurrency if isinstance(self.run_cfg.system, HTTPSystemConfig) else 10)
|
| 122 |
timeout = aiohttp.ClientTimeout(total=getattr(self.run_cfg.system, 'timeout', 30.0))
|
| 123 |
sem = asyncio.Semaphore(getattr(self.run_cfg.system, 'concurrency', 5))
|
|
@@ -129,7 +93,7 @@ class EvaluationRunner:
|
|
| 129 |
async def evaluate(query: Query) -> QueryResult:
|
| 130 |
async with sem:
|
| 131 |
start = perf_counter()
|
| 132 |
-
response = await
|
| 133 |
latency_ms = (perf_counter() - start) * 1000
|
| 134 |
docs = []
|
| 135 |
for item in response.get('documents', response.get('results', [])):
|
|
|
|
| 70 |
path.write_text(json.dumps(self.to_json(), indent=2), encoding='utf-8')
|
| 71 |
|
| 72 |
|
| 73 |
+
from .connectors import SearchConnector
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
class EvaluationRunner:
|
| 77 |
+
def __init__(self, run_cfg: RunConfig, dataset_cfg: DatasetConfig, connector: SearchConnector, dataset_path: Optional[Path] = None) -> None:
|
| 78 |
self.run_cfg = run_cfg
|
| 79 |
self.dataset_cfg = dataset_cfg
|
| 80 |
+
self.connector = connector
|
| 81 |
self.dataset_path = dataset_path
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
async def run_async(self) -> RunResult:
|
| 84 |
dataset = self._load_dataset()
|
|
|
|
| 85 |
connector = aiohttp.TCPConnector(limit=self.run_cfg.system.concurrency if isinstance(self.run_cfg.system, HTTPSystemConfig) else 10)
|
| 86 |
timeout = aiohttp.ClientTimeout(total=getattr(self.run_cfg.system, 'timeout', 30.0))
|
| 87 |
sem = asyncio.Semaphore(getattr(self.run_cfg.system, 'concurrency', 5))
|
|
|
|
| 93 |
async def evaluate(query: Query) -> QueryResult:
|
| 94 |
async with sem:
|
| 95 |
start = perf_counter()
|
| 96 |
+
response = await self.connector.search(session, query)
|
| 97 |
latency_ms = (perf_counter() - start) * 1000
|
| 98 |
docs = []
|
| 99 |
for item in response.get('documents', response.get('results', [])):
|
searchqual/judgers/pairwise_llm_judge.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
import aiohttp
|
| 11 |
+
|
| 12 |
+
from ..core.config import JudgeConfig, JudgeModelConfig
|
| 13 |
+
from ..core.dataset import Query
|
| 14 |
+
from ..core.runner import QueryResult
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import boto3
|
| 18 |
+
except ImportError: # pragma: no cover - optional dependency
|
| 19 |
+
boto3 = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class PairwiseJudgment:
|
| 24 |
+
winner: str
|
| 25 |
+
reason: str
|
| 26 |
+
raw: Dict[str, Any]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PairwiseLLMJudge:
|
| 30 |
+
def __init__(self, cfg: JudgeConfig) -> None:
|
| 31 |
+
self.cfg = cfg
|
| 32 |
+
self.rubric_text = self._load_rubric(cfg.rubric) if cfg.rubric else None
|
| 33 |
+
|
| 34 |
+
async def evaluate(self, query: Query, result_a: QueryResult, result_b: QueryResult) -> Optional[PairwiseJudgment]:
|
| 35 |
+
if not self.cfg.models:
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
# For now, we only use the first model for pairwise comparison
|
| 39 |
+
model_cfg = self.cfg.models[0]
|
| 40 |
+
|
| 41 |
+
async with aiohttp.ClientSession() as session:
|
| 42 |
+
return await self._call_model(session, model_cfg, query, result_a, result_b)
|
| 43 |
+
|
| 44 |
+
async def _call_model(self, session: aiohttp.ClientSession, model_cfg: JudgeModelConfig, query: Query, result_a: QueryResult, result_b: QueryResult) -> Optional[PairwiseJudgment]:
|
| 45 |
+
prompt = self._build_prompt(query, result_a, result_b)
|
| 46 |
+
if model_cfg.provider == 'openai':
|
| 47 |
+
url = model_cfg.params.get('url', 'https://api.openai.com/v1/chat/completions')
|
| 48 |
+
headers = {
|
| 49 |
+
'Authorization': f"Bearer {model_cfg.params.get('api_key', os.environ.get('OPENAI_API_KEY', ''))}",
|
| 50 |
+
'Content-Type': 'application/json',
|
| 51 |
+
}
|
| 52 |
+
payload = {
|
| 53 |
+
'model': model_cfg.model,
|
| 54 |
+
'messages': [
|
| 55 |
+
{'role': 'system', 'content': self.rubric_text or 'You are a strict search-quality judge.'},
|
| 56 |
+
{'role': 'user', 'content': prompt},
|
| 57 |
+
],
|
| 58 |
+
'temperature': model_cfg.params.get('temperature', 0.1),
|
| 59 |
+
}
|
| 60 |
+
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=60)) as resp:
|
| 61 |
+
resp.raise_for_status()
|
| 62 |
+
data = await resp.json()
|
| 63 |
+
content = data['choices'][0]['message']['content']
|
| 64 |
+
return self._parse_response(content, model_cfg)
|
| 65 |
+
raise NotImplementedError(f'Provider {model_cfg.provider} not implemented')
|
| 66 |
+
|
| 67 |
+
def _parse_response(self, text: str, model_cfg: JudgeModelConfig) -> Optional[PairwiseJudgment]:
|
| 68 |
+
try:
|
| 69 |
+
data = json.loads(text)
|
| 70 |
+
winner = data.get('winner')
|
| 71 |
+
reason = data.get('reason')
|
| 72 |
+
if winner in ['A', 'B']:
|
| 73 |
+
return PairwiseJudgment(winner=winner, reason=reason, raw=data)
|
| 74 |
+
except json.JSONDecodeError:
|
| 75 |
+
pass
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
def _build_prompt(self, query: Query, result_a: QueryResult, result_b: QueryResult) -> str:
|
| 79 |
+
rubric = self.rubric_text or 'Compare the two answers and determine which one is better.'
|
| 80 |
+
prompt_lines = [
|
| 81 |
+
rubric,
|
| 82 |
+
'',
|
| 83 |
+
f'Query: {query.text}',
|
| 84 |
+
'',
|
| 85 |
+
'Answer A:',
|
| 86 |
+
result_a.answer,
|
| 87 |
+
'',
|
| 88 |
+
'Answer B:',
|
| 89 |
+
result_b.answer,
|
| 90 |
+
'',
|
| 91 |
+
'Which answer is better? Respond with JSON with keys "winner" (either "A" or "B") and "reason".',
|
| 92 |
+
]
|
| 93 |
+
return '\n'.join(line for line in prompt_lines if line)
|
| 94 |
+
|
| 95 |
+
def _load_rubric(self, path: str) -> str:
|
| 96 |
+
return Path(path).read_text(encoding='utf-8')
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
__all__ = ['PairwiseLLMJudge']
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
streamlit_app.py
CHANGED
|
@@ -42,8 +42,26 @@ with col1:
|
|
| 42 |
selected = st.selectbox('Select run', run_files, format_func=lambda p: p.stem)
|
| 43 |
if st.button('Load run'):
|
| 44 |
data = json.loads(selected.read_text())
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
else:
|
| 48 |
st.info('No runs found yet. Execute a run to populate metrics.')
|
| 49 |
|
|
|
|
| 42 |
selected = st.selectbox('Select run', run_files, format_func=lambda p: p.stem)
|
| 43 |
if st.button('Load run'):
|
| 44 |
data = json.loads(selected.read_text())
|
| 45 |
+
metrics = data.get('metrics', {})
|
| 46 |
+
if metrics:
|
| 47 |
+
st.subheader("Metrics")
|
| 48 |
+
cols = st.columns(len(metrics))
|
| 49 |
+
for i, (k, v) in enumerate(metrics.items()):
|
| 50 |
+
cols[i].metric(k, v)
|
| 51 |
+
|
| 52 |
+
stats = data.get('stats', {})
|
| 53 |
+
if stats:
|
| 54 |
+
st.subheader("Stats")
|
| 55 |
+
st.table(stats)
|
| 56 |
+
|
| 57 |
+
results = data.get('results', [])
|
| 58 |
+
if results:
|
| 59 |
+
st.subheader("Results")
|
| 60 |
+
for res in results:
|
| 61 |
+
st.write(f"**Query:** {res['query']}")
|
| 62 |
+
if res.get('answer'):
|
| 63 |
+
st.write(f"**Answer:** {res['answer']}")
|
| 64 |
+
st.dataframe(res.get('documents', []))
|
| 65 |
else:
|
| 66 |
st.info('No runs found yet. Execute a run to populate metrics.')
|
| 67 |
|