"""Task Bank — LLM-generated tasks derived from the specialist catalog.

Tasks are generated dynamically using GPT-4o-mini based on:
  1. The sector defined in training_config.yaml
  2. The specialist roster in specialist_catalog.yaml
  3. The current curriculum phase (controls complexity)

No hardcoded task lists. Any sector works by swapping the catalog + sector config.
"""

from __future__ import annotations

import json
import os
import random
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import yaml

# Used when complexity_descriptions.yaml is missing or empty.
_DEFAULT_COMPLEXITY_DESCRIPTIONS = {
    "atomic": "a very simple, single-step",
    "simple": "a straightforward, well-scoped",
    "moderate": "a multi-component, realistic",
    "complex": "a complex, multi-system",
    "enterprise": "a large-scale, enterprise-grade",
}

# Complexity classes used when a phase has no (or an empty) entry in the config.
_FALLBACK_COMPLEXITIES = ["simple"]


def _read_yaml_mapping(path: str) -> dict:
    """Parse the YAML file at *path* and return its top-level mapping.

    An empty document (``yaml.safe_load`` returns ``None``) is treated as an
    empty mapping so callers can use ``.get`` unconditionally.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the document's top level is not a mapping.
    """
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def _complexity_config_from(cfg: dict, config_path: str) -> tuple[dict, dict]:
    """Build (complexity-by-phase, complexity-descriptions) from a parsed config.

    *config_path* is only used to locate ``complexity_descriptions.yaml``,
    which is expected to sit next to the training config.
    """
    curriculum = cfg.get("curriculum") or {}
    by_phase = {
        1: curriculum.get("phase1_task_types", ["atomic", "simple"]),
        2: curriculum.get("phase2_task_types", ["moderate"]),
        3: curriculum.get("phase3_task_types", ["complex", "enterprise"]),
    }

    base = os.path.dirname(os.path.abspath(config_path))
    desc_path = os.path.join(base, "complexity_descriptions.yaml")
    try:
        descriptions = _read_yaml_mapping(desc_path)
    except FileNotFoundError:
        descriptions = {}
    if not descriptions:
        # Missing or empty file: use the built-in wording.
        descriptions = dict(_DEFAULT_COMPLEXITY_DESCRIPTIONS)
    return by_phase, descriptions


def _load_complexity_config(config_path: str) -> tuple[dict, dict]:
    """Load COMPLEXITY_BY_PHASE and COMPLEXITY_DESCRIPTIONS from config files.

    Returns:
        ``(by_phase, descriptions)`` where ``by_phase`` maps phase number
        (1-3) to a list of complexity class names and ``descriptions`` maps a
        complexity class name to the phrase inserted into task prompts.
    """
    return _complexity_config_from(_read_yaml_mapping(config_path), config_path)


@dataclass
class Task:
    """A single training task handed to the environment."""

    description: str  # task text shown to the agents
    complexity_class: str  # e.g. "atomic", "moderate", "enterprise"
    domain: str  # sector name the task was generated for


class TaskBank:
    """
    Generates tasks dynamically using GPT-4o-mini.
    Falls back to catalog-derived tasks if OpenAI is unavailable.

    Tasks are pre-cached in batches to avoid per-episode API latency.
    """

    def __init__(
        self,
        phase: int = 1,
        config_path: str = "configs/training_config.yaml",
        catalog_path: str = "configs/specialist_catalog.yaml",
    ):
        """Load sector/curriculum config and the specialist catalog, then pre-fill the cache.

        Args:
            phase: Curriculum phase; selects which complexity classes are generated.
            config_path: Path to the training config YAML.
            catalog_path: Path to the specialist catalog YAML.
        """
        self.phase = phase
        self._cache: list[Task] = []
        self._client: Optional[object] = None
        self._cache_lock = threading.Lock()
        self._refill_running = False
        # Bumped on every phase switch so a refill that started under an
        # older phase can tell it is stale and discard its results.
        self._generation = 0

        # Parse the config once and derive everything from it.
        cfg = _read_yaml_mapping(config_path)

        # Load complexity config from yaml files (not hardcoded)
        self._complexity_by_phase, self._complexity_descriptions = (
            _complexity_config_from(cfg, config_path)
        )

        # Load sector config
        sector_cfg = cfg.get("sector") or {}
        self.sector_name = sector_cfg.get("name", "software_engineering")
        self.sector_description = sector_cfg.get(
            "description", "Software product development"
        )
        self.use_llm = sector_cfg.get("use_llm_task_generation", True)
        self.llm_model = sector_cfg.get("llm_task_model", "gpt-4o-mini")
        self.cache_size = sector_cfg.get("task_cache_size", 50)

        # Load specialist roles from catalog (for context in prompts)
        catalog = _read_yaml_mapping(catalog_path)
        self._specialist_roles = [
            s["role"] for s in (catalog.get("specialists") or [])
        ]

        if self.use_llm:
            self._init_openai()

        # Pre-fill cache
        self._refill_cache()

    def _init_openai(self):
        """Create the OpenAI client; on any failure leave ``_client`` as None.

        Best-effort by design: a missing package or API key must not prevent
        the bank from working with catalog-derived tasks.
        """
        try:
            from openai import OpenAI
            self._client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        except Exception as e:
            print(f"[TaskBank] OpenAI unavailable: {e}. Using catalog-derived tasks.")
            self._client = None

    def _phase_complexities(self, phase: int) -> list[str]:
        """Return the non-empty list of complexity classes for *phase*."""
        configured = self._complexity_by_phase.get(phase)
        if isinstance(configured, str):
            # Tolerate a scalar in the YAML (`phase2_task_types: moderate`).
            return [configured]
        return list(configured) if configured else list(_FALLBACK_COMPLEXITIES)

    def _refill_cache(self):
        """
        Synchronously generate a batch of tasks and extend the cache.

        Thread-safe: the phase and generation are snapshotted under
        _cache_lock, generation happens without the lock, and the cache is
        written under the lock again. If set_phase() ran in between, the
        generated tasks belong to an old phase and are discarded.

        _refill_running is cleared on exit (including on error) so a failed
        refill can never block future background refills.

        Called directly on first fill (init) and on phase switches, and from
        the background thread otherwise.
        """
        with self._cache_lock:
            generation = self._generation
            phase = self.phase

        try:
            complexities = self._phase_complexities(phase)
            n_per_complexity = max(1, self.cache_size // len(complexities))

            new_tasks: list[Task] = []
            for complexity in complexities:
                if self._client and self.use_llm:
                    batch = self._generate_llm_tasks(complexity, n_per_complexity)
                else:
                    batch = self._generate_catalog_tasks(complexity, n_per_complexity)
                new_tasks.extend(batch)
            random.shuffle(new_tasks)

            with self._cache_lock:
                if generation == self._generation:
                    self._cache.extend(new_tasks)
        finally:
            with self._cache_lock:
                # A stale refill must not clear the flag owned by a newer one.
                if generation == self._generation:
                    self._refill_running = False

    def _refill_cache_background(self):
        """Trigger a non-blocking background refill if one isn't already running."""
        with self._cache_lock:
            if self._refill_running:
                return  # already in flight — don't pile up threads
            self._refill_running = True
        worker = threading.Thread(target=self._refill_cache, daemon=True)
        worker.start()

    def _generate_llm_tasks(self, complexity: str, n: int) -> list[Task]:
        """Generate n tasks of the given complexity using GPT-4o-mini.

        Batches requests at max 20 tasks per API call to avoid JSON truncation
        from max_tokens limits. Results are concatenated into a single list.

        Any failure in a batch (API error, empty or non-JSON reply, reply that
        is not a JSON array) is logged and that batch is replaced by
        catalog-derived tasks, so the method never raises for LLM problems.
        The number of tasks returned for a successful batch is whatever the
        model produced after filtering, which may differ from the request.
        """
        if self._client is None:
            return self._generate_catalog_tasks(complexity, n)

        complexity_desc = self._complexity_descriptions.get(complexity, "a realistic")
        roles_str = ", ".join(self._specialist_roles)

        batch_size = 20  # safe upper bound — 20 tasks × ~40 tokens each ≈ 800 tokens
        all_tasks: list[Task] = []

        for batch_start in range(0, n, batch_size):
            batch_n = min(batch_size, n - batch_start)
            prompt = f"""You are generating training tasks for a multi-agent RL environment.

Sector: {self.sector_name}
Sector description: {self.sector_description}
Available specialist roles: {roles_str}

Generate exactly {batch_n} different {complexity_desc} task descriptions for this sector.

Each task should:
- Be 1-2 sentences long
- Be specific and realistic for the {self.sector_name} sector
- Potentially require one or more of the available specialists to complete
- Vary in subject matter (don't repeat similar tasks)

Return ONLY a JSON array of strings, no other text:
["task 1 description", "task 2 description", ...]"""

            try:
                response = self._client.chat.completions.create(
                    model=self.llm_model,
                    max_tokens=1200,
                    messages=[{"role": "user", "content": prompt}],
                )
                # content can be None (e.g. refusal); treat as empty reply.
                raw = (response.choices[0].message.content or "").strip()
                # Strip Markdown code fences the model sometimes adds.
                raw = raw.replace("```json", "").replace("```", "").strip()
                task_strings = json.loads(raw)
                if not isinstance(task_strings, list):
                    # Iterating a dict would silently turn its keys into tasks.
                    raise ValueError(
                        f"expected a JSON array, got {type(task_strings).__name__}"
                    )
                all_tasks.extend(
                    Task(
                        description=text,
                        complexity_class=complexity,
                        domain=self.sector_name,
                    )
                    for text in task_strings
                    if isinstance(text, str) and len(text) > 10
                )
            except Exception as e:
                # Deliberate best-effort boundary: never let the LLM path
                # break task generation.
                print(f"[TaskBank] LLM generation failed for {complexity} batch: {e}. Using fallback.")
                all_tasks.extend(self._generate_catalog_tasks(complexity, batch_n))

        return all_tasks

    def _generate_catalog_tasks(self, complexity: str, n: int) -> list[Task]:
        """
        Fallback: derive tasks from specialist catalog without API calls.
        Produces formulaic but valid tasks for any sector.

        With two or more specialist roles, each task pairs two neighbouring
        roles from a shuffled copy of the roster; with one or zero roles a
        single-role template is used instead.
        """
        complexity_desc = self._complexity_descriptions.get(complexity, "a realistic")
        specialists = list(self._specialist_roles)
        random.shuffle(specialists)
        count = len(specialists)

        tasks: list[Task] = []
        for i in range(n):
            if count >= 2:
                s1 = specialists[i % count]
                s2 = specialists[(i + 1) % count]
                desc = (
                    f"Design {complexity_desc} {self.sector_name} solution "
                    f"involving {s1} and {s2} working together"
                )
            else:
                s1 = specialists[0] if specialists else "specialist"
                desc = (
                    f"Create {complexity_desc} {self.sector_name} deliverable "
                    f"for a {s1}"
                )
            tasks.append(Task(
                description=desc,
                complexity_class=complexity,
                domain=self.sector_name,
            ))
        return tasks

    def _take_task(self) -> Task:
        """Pop one task from the cache, or build a catalog task if it is empty.

        Never blocks on generation: when the cache falls below the low-water
        mark a background refill is requested and this call returns at once.
        """
        low_water = max(5, self.cache_size // 10)

        # One lock acquisition covers the pop and the size check.
        with self._cache_lock:
            task = self._cache.pop() if self._cache else None
            remaining = len(self._cache)
            phase = self.phase

        if task is None:
            # Cache exhausted — generate one catalog task inline (fast, no API)
            complexity = random.choice(self._phase_complexities(phase))
            task = self._generate_catalog_tasks(complexity, 1)[0]

        if remaining < low_water:
            self._refill_cache_background()
        return task

    def sample(self) -> str:
        """
        Sample a random task description for a new episode.

        Never blocks for a refill. When the cache drops below a low-water mark
        (10% of cache_size, but at least 5 tasks) a background thread is
        kicked off to replenish it. If the cache is completely empty (should
        only happen at init or after a phase switch drains it before the
        background fill completes) we fall back to a catalog-derived task
        immediately so reset() is never stalled.
        """
        return self._take_task().description

    def sample_task(self) -> Task:
        """Sample a full Task object.

        The returned task carries the complexity class it was actually
        generated with, so ``complexity_class`` always matches ``description``.
        """
        return self._take_task()

    def set_phase(self, phase: int) -> None:
        """Switch curriculum phase, drop cached tasks and refill synchronously.

        Bumping the generation invalidates any background refill that is
        still running for the previous phase, so its tasks are discarded
        instead of leaking into the new phase's cache.
        """
        with self._cache_lock:
            self.phase = phase
            self._generation += 1
            self._cache.clear()
            # Mark the synchronous refill below as in flight so concurrent
            # sample() calls don't spawn a redundant background thread.
            self._refill_running = True
        self._refill_cache()  # synchronous — phase switches are rare and intentional

    @property
    def pool_size(self) -> int:
        """Number of tasks currently cached."""
        with self._cache_lock:
            return len(self._cache)