# Uploaded by ashish-sarvam via huggingface_hub (commit fc1a684, verified).
from __future__ import annotations
import json
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
from conv_data_gen.config import config
from conv_data_gen.llm import LLMClient
from conv_data_gen.logger import setup_logger
# Module-level logger; ToolEvolver uses it to report failed tool evolutions.
logger = setup_logger(__name__)
class ToolEvolver:
    """Evolve tool definitions through an LLM, driven by weighted random knobs.

    For each tool, a complexity level is drawn from ``complexity_weights``.
    The special key ``"none"`` leaves the tool unchanged. Otherwise an
    evolution type is drawn from ``type_weights``: ``"args_only"`` asks the
    LLM to change only the input args, and any other key asks it to change
    both args and response. Every tool returned by this class carries a
    ``complexity_level`` string.
    """

    def __init__(
        self,
        llm_client: Optional[LLMClient] = None,
        evolution_system_prompt: str = "",
        complexity_weights: Optional[Dict[str, float]] = None,
        type_weights: Optional[Dict[str, float]] = None,
        max_workers: int = 4,
    ) -> None:
        """Create an evolver.

        Args:
            llm_client: Client used for evolution requests. When omitted, a
                default ``LLMClient()`` is constructed.
            evolution_system_prompt: System prompt sent with every request.
            complexity_weights: Optional per-key overrides merged on top of
                the ``complexity_weights`` loaded from the knobs YAML.
            type_weights: Optional per-key overrides merged on top of the
                YAML ``type_weights``.
            max_workers: Thread count used to evolve the tools of one row
                (clamped to at least 1).

        Raises:
            FileNotFoundError: If the knobs YAML is missing or not a mapping.
            ValueError: If the YAML sections are malformed, or if the YAML or
                an override contains a weight that is not a number.
        """
        self.client = llm_client if llm_client is not None else LLMClient()
        self.system_prompt = evolution_system_prompt
        # Knobs from YAML are mandatory; explicit arguments override per key.
        self.complexity_weights, self.type_weights = self._load_knobs()
        if isinstance(complexity_weights, dict):
            # Coerce overrides like the YAML values. A bad entry then fails
            # here, instead of making every later evolution fail silently.
            self.complexity_weights.update(
                self._coerce_weights(complexity_weights, "complexity")
            )
        if isinstance(type_weights, dict):
            self.type_weights.update(
                self._coerce_weights(type_weights, "type")
            )
        self.max_workers = max(1, max_workers)

    @staticmethod
    def _load_yaml(path_str: str) -> Optional[Dict[str, Any]]:
        """Best-effort load of a YAML mapping from *path_str*.

        Returns the parsed mapping. Returns ``None`` when PyYAML is not
        installed, the file does not exist, it cannot be read or parsed, or
        its top level is not a mapping. Failures other than a missing file
        are logged, so the caller's generic error can be traced to a cause.
        """
        from pathlib import Path

        try:
            import yaml  # type: ignore[import-untyped]
        except ImportError:
            logger.warning(
                "[ToolEvolver] PyYAML unavailable; cannot load %s", path_str
            )
            return None
        try:
            path = Path(path_str)
            if not path.exists():
                return None
            with path.open("r", encoding="utf-8") as f:
                data = yaml.safe_load(f)  # type: ignore[no-any-return]
        except Exception as exc:  # best-effort: the caller raises clearly
            logger.warning(
                "[ToolEvolver] failed to read YAML %s: %s", path_str, exc
            )
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _coerce_weights(raw: Dict[Any, Any], label: str) -> Dict[str, float]:
        """Return a copy of *raw* with ``str`` keys and ``float`` weights.

        Args:
            raw: Mapping of knob name to weight.
            label: Section name ("complexity" or "type") used in errors.

        Raises:
            ValueError: If any weight cannot be converted to ``float``.
        """
        out: Dict[str, float] = {}
        for key, value in raw.items():
            try:
                out[str(key)] = float(value)
            except (TypeError, ValueError, OverflowError) as exc:
                raise ValueError(
                    f"Invalid {label} weight for '{key}': {value}"
                ) from exc
        return out

    def _load_knobs(self) -> tuple[Dict[str, float], Dict[str, float]]:
        """Load complexity and type weights from the knobs YAML.

        The path comes from ``config.paths.TOOL_EVOLUTION_KNOBS_YAML``.

        Returns:
            ``(complexity_weights, type_weights)`` as float-valued dicts.

        Raises:
            FileNotFoundError: If the YAML is missing or not a mapping.
            ValueError: If either section is absent, not a mapping, empty,
                or contains a weight that is not a number.
        """
        path_str = str(config.paths.TOOL_EVOLUTION_KNOBS_YAML)
        yaml_obj = self._load_yaml(path_str)
        if not isinstance(yaml_obj, dict):
            raise FileNotFoundError(
                f"Evolution knobs YAML missing or invalid: {path_str}"
            )
        cw = yaml_obj.get("complexity_weights")
        tw = yaml_obj.get("type_weights")
        if not isinstance(cw, dict) or not isinstance(tw, dict):
            raise ValueError(
                "Evolution knobs YAML must define 'complexity_weights' and "
                "'type_weights'"
            )
        out_c = self._coerce_weights(cw, "complexity")
        out_t = self._coerce_weights(tw, "type")
        if not out_c or not out_t:
            raise ValueError("Evolution knobs cannot be empty")
        return out_c, out_t

    @staticmethod
    def _weighted_choice(options: Dict[str, float]) -> str:
        """Pick a key from *options* with probability proportional to weight.

        Negative weights are treated as zero. A zero-weight key is never
        chosen unless every weight is zero; in that case the last key is
        returned.

        Raises:
            ValueError: If *options* is empty.
        """
        if not options:
            raise ValueError("weighted choice requires at least one option")
        items = [(key, max(0.0, float(w))) for key, w in options.items()]
        total = sum(w for _, w in items)
        if total <= 0.0:
            return items[-1][0]
        r = random.random() * total  # r in [0, total)
        upto = 0.0
        fallback = items[-1][0]
        for key, w in items:
            if w <= 0.0:
                continue
            fallback = key
            upto += w
            # Strict "<" keeps a key's interval half-open, so a zero-width
            # (zero-weight) key can never be hit, even when r == 0.0.
            if r < upto:
                return key
        # Float rounding can leave r marginally >= upto; return the last
        # positively-weighted key rather than a zero-weight one.
        return fallback

    @staticmethod
    def _ensure_complexity_level(tool: Dict[str, Any]) -> None:
        """Set ``tool["complexity_level"]`` in place unless it is a string.

        The level is derived from the number of ``function_args`` entries:
        0-2 gives "low", 3-4 gives "mid", 5 or more gives "high". A non-dict
        ``function_args`` is treated as "low".
        """
        if isinstance(tool.get("complexity_level"), str):
            return
        args = tool.get("function_args") or {}
        if not isinstance(args, dict):
            tool["complexity_level"] = "low"
            return
        num_args = len(args)
        if num_args <= 2:
            level = "low"
        elif num_args <= 4:
            level = "mid"
        else:
            level = "high"
        tool["complexity_level"] = level

    def _request_evolution(
        self, context: Dict[str, str], tool: Dict[str, Any], complexity: str
    ) -> Optional[Dict[str, Any]]:
        """Ask the LLM to rewrite *tool* at the given *complexity*.

        The evolution type (args only, or args and response) is drawn from
        ``self.type_weights``.

        Returns:
            The dict found under ``"tool"`` in the parsed JSON reply, or
            ``None`` if the reply is not a JSON object with such a dict.
        """
        type_pick = self._weighted_choice(self.type_weights)
        if type_pick == "args_only":
            instruction = (
                f"make it {complexity} difficulty, change only input args"
            )
        else:
            instruction = (
                f"make it {complexity} difficulty, change both args "
                f"and response"
            )
        payload = {
            "company": context.get("company", ""),
            "agent_type": context.get("agent_type", ""),
            "use_case": context.get("use_case", ""),
            "tool": tool,
            "instruction": instruction,
        }
        user_prompt = json.dumps(payload, ensure_ascii=False)
        resp = self.client.get_llm_response_json(
            messages=[{"role": "user", "content": user_prompt}],
            model=config.models.TOOL_MODEL,
            system_prompt=self.system_prompt,
            max_tokens=config.models.TOOL_MODEL_MAX_TOKENS,
        )
        parsed = self.client.safe_parse_json(resp.get("text", ""))
        if not isinstance(parsed, dict):
            return None
        new_tool = parsed.get("tool")
        return new_tool if isinstance(new_tool, dict) else None

    def _evolve_one(
        self, context: Dict[str, str], tool: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Evolve a single tool, falling back to the original on failure.

        A complexity is drawn from ``complexity_weights``. ``"none"`` keeps
        the tool unchanged; any other pick requests a rewrite from the LLM.
        Errors are logged and the original tool is returned. The returned
        tool always has ``complexity_level`` set, which may mutate *tool* in
        place.
        """
        try:
            comp_pick = self._weighted_choice(self.complexity_weights)
            if comp_pick != "none":
                evolved = self._request_evolution(context, tool, comp_pick)
                if evolved is not None:
                    self._ensure_complexity_level(evolved)
                    return evolved
        except Exception as e:
            logger.warning("[ToolEvolver] evolve failed: %s", e)
        self._ensure_complexity_level(tool)
        return tool

    def evolve_row_tools(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Evolve every dict entry of ``row["tools"]`` concurrently.

        Tool positions are preserved, and non-dict entries are kept as-is.
        If a worker raises, the original tool is kept and a warning is
        logged. ``row["tools"]`` is replaced in place and *row* is returned.
        Rows without a non-empty ``tools`` list are returned untouched.
        """
        tools = row.get("tools")
        if not isinstance(tools, list) or not tools:
            return row
        context = {
            "company": str(row.get("company", "")),
            "agent_type": str(row.get("agent_type", "")),
            "use_case": str(row.get("use_case", "")),
        }
        # Start from a copy so non-dict and failed entries keep their value.
        ordered: List[Any] = list(tools)
        with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
            # Key futures by position: results must land in the slot of the
            # tool they came from, whatever order the threads finish in.
            pending = {
                idx: ex.submit(self._evolve_one, context, t)
                for idx, t in enumerate(tools)
                if isinstance(t, dict)
            }
            for idx, fut in pending.items():
                try:
                    ordered[idx] = fut.result()
                except Exception as exc:
                    logger.warning(
                        "[ToolEvolver] tool %d evolution failed: %s",
                        idx,
                        exc,
                    )
        row["tools"] = ordered
        return row

    def evolve_rows(
        self, rows: List[Dict[str, Any]], outer_workers: int
    ) -> List[Dict[str, Any]]:
        """Evolve the tools of every row using ``outer_workers`` threads.

        Returns a list with the same order and length as *rows*. A row whose
        evolution raises is kept unchanged, and a warning is logged. An
        empty *rows* is returned as-is.
        """
        if not rows:
            return rows
        out: List[Dict[str, Any]] = list(rows)
        with ThreadPoolExecutor(max_workers=max(1, outer_workers)) as ex:
            futs = [ex.submit(self.evolve_row_tools, r) for r in rows]
            # Consume futures in submission order to keep row order stable.
            for idx, fut in enumerate(futs):
                try:
                    out[idx] = fut.result()
                except Exception as exc:
                    logger.warning(
                        "[ToolEvolver] row %d evolution failed: %s", idx, exc
                    )
        return out