from __future__ import annotations import json import random from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict, List, Optional from conv_data_gen.config import config from conv_data_gen.llm import LLMClient from conv_data_gen.logger import setup_logger logger = setup_logger(__name__) class ToolEvolver: """Encapsulates tool evolution with weighted randomization.""" def __init__( self, llm_client: Optional[LLMClient] = None, evolution_system_prompt: str = "", complexity_weights: Optional[Dict[str, float]] = None, type_weights: Optional[Dict[str, float]] = None, max_workers: int = 4, ) -> None: self.client = llm_client or LLMClient() self.system_prompt = evolution_system_prompt # Load knobs from YAML (mandatory); allow explicit args to override self.complexity_weights, self.type_weights = self._load_knobs() if isinstance(complexity_weights, dict): self.complexity_weights.update(complexity_weights) if isinstance(type_weights, dict): self.type_weights.update(type_weights) self.max_workers = max(1, max_workers) @staticmethod def _load_yaml(path_str: str) -> Optional[Dict[str, Any]]: try: import yaml # type: ignore[import-untyped] except Exception: return None try: from pathlib import Path p = Path(path_str) if not p.exists(): return None with open(p, "r", encoding="utf-8") as f: data = yaml.safe_load(f) # type: ignore[no-any-return] if isinstance(data, dict): return data return None except Exception: return None def _load_knobs(self) -> tuple[Dict[str, float], Dict[str, float]]: path_str = str(config.paths.TOOL_EVOLUTION_KNOBS_YAML) yaml_obj = self._load_yaml(path_str) if not isinstance(yaml_obj, dict): raise FileNotFoundError( f"Evolution knobs YAML missing or invalid: {path_str}" ) cw = yaml_obj.get("complexity_weights") tw = yaml_obj.get("type_weights") if not isinstance(cw, dict) or not isinstance(tw, dict): raise ValueError( "Evolution knobs YAML must define 'complexity_weights' and " "'type_weights'" ) out_c: Dict[str, float] = {} out_t: Dict[str, float] = {} # Strictly coerce to floats; raise on invalid entries for k, v in cw.items(): try: out_c[str(k)] = float(v) except Exception as exc: raise ValueError( f"Invalid complexity weight for '{k}': {v}" ) from exc for k, v in tw.items(): try: out_t[str(k)] = float(v) except Exception as exc: raise ValueError( f"Invalid type weight for '{k}': {v}" ) from exc if not out_c or not out_t: raise ValueError("Evolution knobs cannot be empty") return out_c, out_t @staticmethod def _weighted_choice(options: Dict[str, float]) -> str: items = list(options.items()) total = sum(max(0.0, float(w)) for _, w in items) or 1.0 r = random.random() * total upto = 0.0 for key, weight in items: w = max(0.0, float(weight)) if upto + w >= r: return key upto += w return items[-1][0] @staticmethod def _ensure_complexity_level(tool: Dict[str, Any]) -> None: if isinstance(tool.get("complexity_level"), str): return args = tool.get("function_args") or {} if not isinstance(args, dict): tool["complexity_level"] = "low" return num_args = len(args) if num_args <= 2: tool["complexity_level"] = "low" elif num_args <= 4: tool["complexity_level"] = "mid" else: tool["complexity_level"] = "high" def _evolve_one( self, context: Dict[str, str], tool: Dict[str, Any] ) -> Dict[str, Any]: try: comp_pick = self._weighted_choice(self.complexity_weights) if comp_pick == "none": self._ensure_complexity_level(tool) return tool type_pick = self._weighted_choice(self.type_weights) if type_pick == "args_only": instruction = ( f"make it {comp_pick} difficulty, change only input args" ) else: instruction = ( f"make it {comp_pick} difficulty, change both args " f"and response" ) payload = { "company": context.get("company", ""), "agent_type": context.get("agent_type", ""), "use_case": context.get("use_case", ""), "tool": tool, "instruction": instruction, } user_prompt = json.dumps(payload, ensure_ascii=False) resp = self.client.get_llm_response_json( messages=[{"role": "user", "content": user_prompt}], model=config.models.TOOL_MODEL, system_prompt=self.system_prompt, max_tokens=config.models.TOOL_MODEL_MAX_TOKENS, ) parsed = self.client.safe_parse_json(resp.get("text", "")) if not isinstance(parsed, dict): self._ensure_complexity_level(tool) return tool new_tool = parsed.get("tool") if isinstance(new_tool, dict): self._ensure_complexity_level(new_tool) return new_tool self._ensure_complexity_level(tool) return tool except Exception as e: logger.warning("[ToolEvolver] evolve failed: %s", e) self._ensure_complexity_level(tool) return tool def evolve_row_tools(self, row: Dict[str, Any]) -> Dict[str, Any]: context = { "company": str(row.get("company", "")), "agent_type": str(row.get("agent_type", "")), "use_case": str(row.get("use_case", "")), } tools = row.get("tools") if not isinstance(tools, list) or not tools: return row results: List[Dict[str, Any]] = [] with ThreadPoolExecutor(max_workers=self.max_workers) as ex: futures = [ ex.submit(self._evolve_one, context, t) for t in tools if isinstance(t, dict) ] for fu in as_completed(futures): try: results.append(fu.result()) except Exception: results.append({}) ordered: List[Dict[str, Any]] = [] idx = 0 for t in tools: if isinstance(t, dict): ordered.append(results[idx]) idx += 1 else: ordered.append(t) row["tools"] = ordered return row def evolve_rows( self, rows: List[Dict[str, Any]], outer_workers: int ) -> List[Dict[str, Any]]: if not rows: return rows out: List[Dict[str, Any]] = [] with ThreadPoolExecutor(max_workers=max(1, outer_workers)) as ex: futs = [ex.submit(self.evolve_row_tools, r) for r in rows] for fu in as_completed(futs): try: out.append(fu.result()) except Exception: pass return out or rows