| """parse_config tool — extract WorkloadConfig from a user's training script/args. |
| |
| Three input shapes are supported, dispatched by file extension: |
| |
| * ``.py`` — Python source. Walks the AST for HuggingFace |
| ``TrainingArguments(...)`` calls, ``DataLoader(...)`` calls, |
| ``torch.compile(...)`` calls, ``model.gradient_checkpointing_enable()`` |
| invocations, and ``os.environ["X"] = "..."`` assignments. |
| * ``.json`` — ``json.load`` of a dict shaped like ``TrainingArguments`` kwargs. |
| * ``.yaml`` / ``.yml`` — ``yaml.safe_load`` of the same shape. |
| |
| Before storing the source in ``WorkloadConfig.raw_source`` we run a regex |
| redaction pass for common secret shapes (HF/OpenAI/GitHub/Bearer tokens, |
| ``$HOME`` paths, ``s3://``/``ws[s]://`` URIs). Each scrubbed pattern is |
| recorded in ``WorkloadConfig.redactions`` so the report can flag what was |
| removed. |
| |
| Anything that doesn't map cleanly onto a ``WorkloadConfig`` field is shoved |
| into ``extras`` so we don't lose information. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import ast |
| import json |
| import re |
| from pathlib import Path |
| from typing import Any |
|
|
| import yaml |
|
|
| from agent.schemas import ToolResult, WorkloadConfig |
| from agent.tools import Tool |
|
|
|
|
| |
| |
| |
|
|
| |
| |
# Ordered ``(label, pattern)`` pairs applied by the redaction pass. The label
# is embedded in the ``<REDACTED:label>`` placeholder and reported back to the
# caller, so labels are part of the output format and must stay stable.
_REDACTION_PATTERNS: list[tuple[str, re.Pattern[str]]] = [
    ("openai_key", re.compile(r"sk-[A-Za-z0-9]{20,}")),
    ("hf_token", re.compile(r"hf_[A-Za-z0-9]{20,}")),
    # GitHub issues several token families, not only OAuth (``gho_``):
    # ``ghp_`` (personal access), ``ghu_`` (user-to-server), ``ghs_``
    # (server-to-server), ``ghr_`` (refresh) and fine-grained ``github_pat_``
    # tokens, whose body also contains underscores.
    (
        "github_token",
        re.compile(r"(?:gh[pousr]_|github_pat_)[A-Za-z0-9_]{20,}"),
    ),
    ("bearer_token", re.compile(r"Bearer\s+[A-Za-z0-9._\-]{8,}")),
    # Only the ``/home/<user>`` prefix is replaced; the rest of the path stays.
    ("home_path", re.compile(r"/home/[^/\s\"'`]+")),
    ("s3_uri", re.compile(r"s3://[^\s\"'`]+")),
    ("ws_uri", re.compile(r"wss?://[^\s\"'`]+")),
]
|
|
|
|
def _redact(source: str) -> tuple[str, list[str]]:
    """Scrub every known secret shape out of ``source``.

    Patterns run in ``_REDACTION_PATTERNS`` order, each one operating on the
    output of the previous one. Every match is replaced by a
    ``<REDACTED:label>`` placeholder.

    Returns:
        ``(scrubbed_text, labels)`` where ``labels`` lists, in pattern order,
        each label whose pattern matched at least once.
    """
    fired: list[str] = []
    text = source
    for name, regex in _REDACTION_PATTERNS:
        # ``subn`` reports how many replacements it made, so one pass both
        # rewrites the text and tells us whether the pattern fired.
        text, replaced = regex.subn(f"<REDACTED:{name}>", text)
        if replaced:
            fired.append(name)
    return text, fired
|
|
|
|
def _redact_extras(extras: dict[str, Any]) -> list[str]:
    """Scrub secret-shaped strings out of ``extras`` in place.

    Every string leaf is passed through ``_redact``. Dicts, lists and tuples
    are walked recursively and rebuilt with scrubbed contents (dict keys are
    left untouched); any other leaf is returned unchanged. Tuples are walked
    because ``ast.literal_eval`` yields them for Python sources, so a path or
    token written as ``("...",)`` would otherwise skip redaction.

    Args:
        extras: Mapping whose top-level values are replaced with their
            scrubbed equivalents.

    Returns:
        The labels of every pattern that fired, in encounter order. A label
        is recorded once per string it matched in, so duplicates are possible.
    """
    labels: list[str] = []

    def _scrub(value: Any) -> Any:
        if isinstance(value, str):
            cleaned, hits = _redact(value)
            labels.extend(hits)
            return cleaned
        if isinstance(value, dict):
            return {k: _scrub(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_scrub(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_scrub(v) for v in value)
        return value

    # Snapshot the keys so reassigning values never disturbs the iteration.
    for key in list(extras):
        extras[key] = _scrub(extras[key])
    return labels
|
|
|
|
| |
| |
| |
|
|
|
|
def _literal(node: ast.AST) -> Any:
    """Best-effort literal evaluation of an AST node.

    Returns the Python value for literal expressions and ``None`` for
    anything ``ast.literal_eval`` rejects. Note that a literal ``None``
    constant also evaluates to ``None``; callers that care distinguish the
    two by checking whether the node is an ``ast.Constant``.

    ``ast.literal_eval`` is documented to raise ``ValueError``, ``TypeError``,
    ``SyntaxError``, ``MemoryError`` or ``RecursionError`` on malformed input
    (e.g. ``TypeError`` for an unhashable dict key such as ``{[1]: 2}``), so
    all of them are treated as "not a literal" instead of aborting the parse.
    """
    try:
        return ast.literal_eval(node)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None
|
|
|
|
def _attr_chain(node: ast.AST) -> str:
    """Return the dotted path (``a.b.c``) spelled by a Name/Attribute chain.

    The chain must bottom out in a plain ``ast.Name``; anything else (a call,
    a subscript, a constant, ...) yields the empty string.
    """
    trailing: list[str] = []
    head = node
    # Peel attributes off from the outside in; they come out right-to-left.
    while isinstance(head, ast.Attribute):
        trailing.append(head.attr)
        head = head.value
    if not isinstance(head, ast.Name):
        return ""
    trailing.reverse()
    return ".".join([head.id, *trailing])
|
|
|
|
def _call_name(node: ast.Call) -> str:
    """Return the dotted name a call is made on (e.g. ``torch.compile``).

    Bare names give the name itself, attribute chains give their dotted
    form, and any other callee expression gives the empty string.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        return _attr_chain(callee)
    return callee.id if isinstance(callee, ast.Name) else ""
|
|
|
|
def _kwargs_to_dict(
    node: ast.Call, dict_constants: dict[str, dict[str, Any]] | None = None
) -> dict[str, Any]:
    """Collect the literal keyword arguments of a call.

    Named keywords are kept when their value evaluates to a literal (an
    explicit ``None`` constant included); non-literal values are dropped.

    ``**name`` splats are expanded when ``name`` has an entry in
    ``dict_constants`` (module-level literal dicts). This covers the common
    ``_ta = dict(...); TrainingArguments(**_ta)`` refactor, which would
    otherwise make every field vanish. Splatted entries never replace a key
    that is already present, and a later named keyword overwrites them.
    """
    collected: dict[str, Any] = {}
    known_dicts = dict_constants if dict_constants else {}
    for keyword in node.keywords:
        name, expr = keyword.arg, keyword.value
        if name is not None:
            value = _literal(expr)
            # ``_literal`` returns None for non-literals as well, so a real
            # ``None`` is only accepted when the node is a constant.
            if value is not None or isinstance(expr, ast.Constant):
                collected[name] = value
            continue
        # ``**splat``: only bare names with a known literal dict are followed.
        if not isinstance(expr, ast.Name):
            continue
        splatted = known_dicts.get(expr.id)
        if splatted:
            for key, item in splatted.items():
                if key not in collected:
                    collected[key] = item
    return collected
|
|
|
|
def _collect_dict_constants(tree: ast.AST) -> dict[str, dict[str, Any]]:
    """Map module-level names to the literal dicts assigned to them.

    Recognises ``NAME = dict(k=v, ...)`` and ``NAME = {"k": v, ...}`` among
    the statements in ``tree.body``. Entries with non-literal values,
    non-string keys or ``**`` splats are skipped, and a dict that ends up
    empty is not recorded. When one statement assigns several names they all
    share the same resolved dict.

    Used by ``_kwargs_to_dict`` to follow ``**NAME`` splats.
    """
    found: dict[str, dict[str, Any]] = {}
    for statement in getattr(tree, "body", []):
        if not isinstance(statement, ast.Assign):
            continue
        names = [t.id for t in statement.targets if isinstance(t, ast.Name)]
        if not names:
            continue

        rhs = statement.value
        # Normalise both spellings to (key, value_node) pairs first.
        if (
            isinstance(rhs, ast.Call)
            and isinstance(rhs.func, ast.Name)
            and rhs.func.id == "dict"
        ):
            pairs = [(kw.arg, kw.value) for kw in rhs.keywords if kw.arg is not None]
        elif isinstance(rhs, ast.Dict):
            pairs = [
                (key_node.value, value_node)
                for key_node, value_node in zip(rhs.keys, rhs.values)
                if isinstance(key_node, ast.Constant) and isinstance(key_node.value, str)
            ]
        else:
            continue

        entries: dict[str, Any] = {}
        for key, value_node in pairs:
            value = _literal(value_node)
            if value is not None or isinstance(value_node, ast.Constant):
                entries[key] = value
        if entries:
            for name in names:
                found[name] = entries
    return found
|
|
|
|
| |
| |
| |
|
|
| |
# HuggingFace ``TrainingArguments`` keyword -> ``WorkloadConfig`` field.
# Several source keys may feed the same field (e.g. both sequence-length
# spellings); whichever is applied last wins.
_TRAINING_ARGS_MAP: dict[str, str] = {
    # Batch / sequence geometry
    "per_device_train_batch_size": "batch_size",
    "gradient_accumulation_steps": "grad_accum_steps",
    "max_seq_length": "seq_len",
    "model_max_length": "seq_len",
    # Optimisation and memory/compile switches
    "optim": "optimizer",
    "gradient_checkpointing": "gradient_checkpointing",
    "torch_compile": "torch_compile",
    "learning_rate": "lr",
    "warmup_steps": "warmup_steps",
    # Input pipeline
    "dataloader_num_workers": "dataloader_workers",
    "dataloader_pin_memory": "dataloader_pin_memory",
    "dataloader_prefetch_factor": "dataloader_prefetch_factor",
    "dataloader_persistent_workers": "dataloader_persistent_workers",
    # Model identity
    "model_name_or_path": "model_name",
    "model_name": "model_name",
}

# ``DataLoader(...)`` keyword -> ``WorkloadConfig`` field.
_DATALOADER_MAP: dict[str, str] = {
    "num_workers": "dataloader_workers",
    "pin_memory": "dataloader_pin_memory",
    "prefetch_factor": "dataloader_prefetch_factor",
    "persistent_workers": "dataloader_persistent_workers",
}

# Values accepted verbatim by the precision / attention coercion helpers.
_VALID_PRECISIONS = {"bf16", "fp16", "fp32", "fp8"}
_VALID_ATTENTION = {"eager", "flash", "flash_rocm", "sdpa", "unknown"}
|
|
|
|
def _coerce_precision(args: dict[str, Any]) -> str | None:
    """Derive a precision label from explicit or HF-style flag arguments.

    An explicit ``precision`` string wins when it is one of
    ``_VALID_PRECISIONS``. Otherwise the boolean flags are consulted in
    priority order ``bf16`` > ``fp16`` > ``tf32`` (reported as ``"fp32"``);
    a flag only counts when it is exactly ``True``. Returns ``None`` when
    nothing applies.
    """
    requested = args.get("precision")
    if isinstance(requested, str) and requested in _VALID_PRECISIONS:
        return requested
    flag_to_precision = (("bf16", "bf16"), ("fp16", "fp16"), ("tf32", "fp32"))
    for flag, precision in flag_to_precision:
        if args.get(flag) is True:
            return precision
    return None
|
|
|
|
def _coerce_attention(args: dict[str, Any]) -> str | None:
    """Normalise an attention-implementation setting to a known label.

    Reads ``attn_implementation`` (falling back to ``attention_impl`` when
    the former is missing or falsy) and lower-cases it. A value that is
    already in ``_VALID_ATTENTION`` is returned as-is; otherwise any value
    containing ``"flash"`` maps to ``"flash_rocm"`` when it also mentions
    ``"rocm"`` and to ``"flash"`` when it does not (e.g.
    ``"flash_attention_2"``). Non-string or unrecognised values give ``None``.
    """
    impl = args.get("attn_implementation") or args.get("attention_impl")
    if not isinstance(impl, str):
        return None
    impl = impl.lower()
    # Covers "eager" and "sdpa" too, so no separate checks are needed below.
    if impl in _VALID_ATTENTION:
        return impl
    if "flash" in impl:
        return "flash_rocm" if "rocm" in impl else "flash"
    return None
|
|
|
|
def _apply_kwargs(
    payload: dict[str, Any],
    raw: dict[str, Any],
    extras: dict[str, Any],
    field_map: dict[str, str],
) -> None:
    """Route each ``raw`` entry into ``payload`` or ``extras``.

    Keys present in ``field_map`` are written to ``payload`` under their
    mapped name (overwriting any earlier value); every other key is kept
    verbatim in ``extras``. Entries whose value is ``None`` are ignored.
    Both target dicts are mutated in place.
    """
    for name, value in raw.items():
        if value is None:
            continue
        mapped = field_map.get(name)
        bucket, slot = (extras, name) if mapped is None else (payload, mapped)
        bucket[slot] = value
|
|
|
|
| |
| |
| |
|
|
|
|
def _extract_from_python(source: str) -> tuple[dict[str, Any], dict[str, Any]]:
    """Walk a Python source's AST and pull out workload settings.

    Two passes over the tree:

    1. Collect module-level ``NAME = <literal>`` assignments (and literal
       dicts, via ``_collect_dict_constants``) so identifiers passed to
       ``from_pretrained(MODEL_ID)`` or splatted with ``**cfg`` can be resolved.
    2. Visit every node, recognising ``TrainingArguments`` /
       ``Seq2SeqTrainingArguments``, ``DataLoader``, ``torch.compile`` /
       ``<obj>.compile``, ``gradient_checkpointing_enable``,
       ``from_pretrained``, ``LoraConfig`` calls and
       ``os.environ["X"] = ...`` assignments.

    Args:
        source: Python source text.

    Returns:
        ``(payload, extras)``: fields mapped onto the config schema, and
        everything that had no mapped field.

    Raises:
        SyntaxError: If ``source`` is not valid Python.
    """
    payload: dict[str, Any] = {}
    extras: dict[str, Any] = {}
    env_vars: dict[str, str] = {}
    raw_training_args: dict[str, Any] = {}
    saw_torch_compile_call = False
    saw_grad_ckpt_enable_call = False
    # ``<root>.compile(...)`` calls that have nothing to do with torch.
    # Without this, any script using ``re.compile`` reports torch_compile=True.
    non_torch_compile_roots = {"re", "regex", "py_compile"}

    tree = ast.parse(source)

    # Pass 1: module-level literal constants, for resolving bare names later.
    constants: dict[str, Any] = {}
    for stmt in tree.body:
        if isinstance(stmt, ast.Assign):
            val = _literal(stmt.value)
            if val is None:
                continue
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    constants[target.id] = val

    # Literal dicts assigned at module level, for following ``**NAME`` splats.
    dict_constants = _collect_dict_constants(tree)

    def _arg_value(node: ast.AST) -> Any:
        """Literal eval, falling back to the constants table for bare Names."""
        val = _literal(node)
        if val is not None:
            return val
        if isinstance(node, ast.Name) and node.id in constants:
            return constants[node.id]
        return None

    # Pass 2: visit every node in the tree.
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            name = _call_name(node)
            short = name.rsplit(".", 1)[-1] if name else ""

            if short in ("TrainingArguments", "Seq2SeqTrainingArguments"):
                # Deferred: precision/attention need the full kwarg set.
                raw_training_args.update(_kwargs_to_dict(node, dict_constants))

            elif short == "DataLoader":
                kw = _kwargs_to_dict(node, dict_constants)
                _apply_kwargs(payload, kw, extras, _DATALOADER_MAP)

            elif name == "torch.compile" or (
                short == "compile"
                and isinstance(node.func, ast.Attribute)
                and name.split(".", 1)[0] not in non_torch_compile_roots
            ):
                # Either ``torch.compile(model)`` or a method-style
                # ``model.compile()``; applied after TrainingArguments below.
                saw_torch_compile_call = True

            elif short == "gradient_checkpointing_enable":
                # Applied after TrainingArguments below, like torch.compile.
                saw_grad_ckpt_enable_call = True

            elif short == "from_pretrained":
                if node.args:
                    val = _arg_value(node.args[0])
                    # First resolvable from_pretrained target names the model.
                    if isinstance(val, str) and "model_name" not in payload:
                        payload["model_name"] = val
                kw = _kwargs_to_dict(node)
                attn = _coerce_attention(kw)
                if attn:
                    payload["attention_impl"] = attn

            elif short == "LoraConfig":
                kw = _kwargs_to_dict(node)
                if "r" in kw and isinstance(kw["r"], int):
                    payload["lora_rank"] = kw["r"]
                for k, v in kw.items():
                    if k != "r":
                        extras[f"lora.{k}"] = v

        elif isinstance(node, ast.Assign):
            # ``os.environ["KEY"] = <literal>``
            for target in node.targets:
                if (
                    isinstance(target, ast.Subscript)
                    and _attr_chain(target.value) == "os.environ"
                ):
                    key = _literal(target.slice)
                    val = _literal(node.value)
                    if isinstance(key, str) and val is not None:
                        env_vars[key] = str(val)

    if raw_training_args:
        prec = _coerce_precision(raw_training_args)
        if prec:
            payload["precision"] = prec
        attn = _coerce_attention(raw_training_args)
        if attn:
            payload["attention_impl"] = attn
        _apply_kwargs(payload, raw_training_args, extras, _TRAINING_ARGS_MAP)
        # These flags were consumed by the coercion helpers above.
        for k in ("bf16", "fp16", "tf32", "attn_implementation"):
            extras.pop(k, None)

    # Explicit calls in the script take effect regardless of what a literal
    # ``TrainingArguments(...)`` flag says, so they are applied last.
    if saw_torch_compile_call:
        payload["torch_compile"] = True
    if saw_grad_ckpt_enable_call:
        payload["gradient_checkpointing"] = True

    if env_vars:
        payload["env_vars"] = env_vars

    return payload, extras
|
|
|
|
| |
| |
| |
|
|
|
|
def _extract_from_dict(data: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Map a flat, TrainingArguments-shaped dict onto the config schema.

    Besides the flat keys, a non-empty nested ``dataloader`` mapping
    (``DataLoader`` kwargs) and an ``env_vars`` / ``env`` mapping are
    recognised. Returns ``(payload, extras)`` where ``extras`` holds every
    key that had no mapped field.
    """
    payload: dict[str, Any] = {}
    extras: dict[str, Any] = {}

    precision = _coerce_precision(data)
    if precision:
        payload["precision"] = precision
    attention = _coerce_attention(data)
    if attention:
        payload["attention_impl"] = attention

    _apply_kwargs(payload, data, extras, _TRAINING_ARGS_MAP)

    # Nested DataLoader settings: applied after (and so overriding) flat keys.
    loader_cfg = data.get("dataloader")
    if isinstance(loader_cfg, dict) and loader_cfg:
        _apply_kwargs(payload, loader_cfg, extras, _DATALOADER_MAP)
        extras.pop("dataloader", None)

    # Environment variables are stringified on both sides.
    environment = data.get("env_vars") or data.get("env")
    if isinstance(environment, dict):
        payload["env_vars"] = {
            str(name): str(value) for name, value in environment.items()
        }
        for consumed in ("env_vars", "env"):
            extras.pop(consumed, None)

    # Flags already folded into precision / attention_impl above.
    for consumed in ("bf16", "fp16", "tf32", "attn_implementation"):
        extras.pop(consumed, None)

    return payload, extras
|
|
|
|
| |
| |
| |
|
|
|
|
def _build_config(
    payload: dict[str, Any],
    extras: dict[str, Any],
    raw_source: str,
    redactions: list[str],
) -> WorkloadConfig:
    """Build a ``WorkloadConfig`` from extracted fields.

    ``payload`` is normalised in place before construction:

    * ``model_name`` defaults to ``"unknown"`` when absent.
    * Integer fields and ``lr`` are coerced with ``int`` / ``float``; a value
      that cannot be coerced is moved into ``extras`` under the same key.
    * Boolean fields are coerced to ``bool``. Strings are interpreted by
      meaning (``"false"``, ``"no"``, ``"off"``, ``"0"`` and blank are
      ``False``) because ``bool("false")`` would otherwise be ``True``.

    Args:
        payload: Mapped config fields; mutated in place.
        extras: Unmapped keys; receives values that fail numeric coercion.
        raw_source: Source text to store (expected to be redacted already).
        redactions: Labels of the redaction patterns that fired.

    Returns:
        The constructed ``WorkloadConfig``.
    """
    payload.setdefault("model_name", "unknown")

    int_fields = (
        "batch_size",
        "grad_accum_steps",
        "seq_len",
        "warmup_steps",
        "dataloader_workers",
        "lora_rank",
        "dataloader_prefetch_factor",
    )
    for field in int_fields:
        if payload.get(field) is not None:
            try:
                payload[field] = int(payload[field])
            except (TypeError, ValueError):
                # Keep the unparseable value visible instead of dropping it.
                extras[field] = payload.pop(field)

    if payload.get("lr") is not None:
        try:
            payload["lr"] = float(payload["lr"])
        except (TypeError, ValueError):
            extras["lr"] = payload.pop("lr")

    bool_fields = (
        "gradient_checkpointing",
        "torch_compile",
        "dataloader_pin_memory",
        "dataloader_persistent_workers",
    )
    false_words = {"false", "0", "no", "off", ""}
    true_words = {"true", "1", "yes", "on"}
    for field in bool_fields:
        value = payload.get(field)
        if value is None:
            continue
        if isinstance(value, str):
            word = value.strip().lower()
            if word in false_words:
                payload[field] = False
            elif word in true_words:
                payload[field] = True
            else:
                # Unrecognised text: fall back to plain truthiness.
                payload[field] = bool(value)
        else:
            payload[field] = bool(value)

    payload["raw_source"] = raw_source
    payload["redactions"] = redactions
    payload["extras"] = extras
    return WorkloadConfig(**payload)
|
|
|
|
def _parse_config_full(file_path: str) -> WorkloadConfig | ToolResult:
    """Parse a config artefact into the **complete** ``WorkloadConfig``.

    The format is chosen from the file suffix: ``.py`` goes through the AST
    extractor, ``.json`` and ``.yaml``/``.yml`` through the dict extractor.
    The result includes ``raw_source`` (redacted); the public
    ``_parse_config`` wraps this function and trims that field from the
    LLM-facing tool result. Tests use this function to verify redaction.

    Extraction runs on the *unredacted* text so values parse correctly, so
    every extracted string is scrubbed afterwards: ``extras`` and all of
    ``payload`` (``env_vars``, ``model_name``, ...). A local path or
    ``s3://`` URI passed to ``from_pretrained`` therefore cannot reach the
    result unredacted.

    Args:
        file_path: Path to the training script or config file.

    Returns:
        The ``WorkloadConfig`` on success, or a failed ``ToolResult``
        describing the problem. No exception escapes this function.
    """
    try:
        path = Path(file_path)
        if not path.exists():
            return ToolResult(ok=False, error=f"File not found: {file_path}")

        raw = path.read_text(encoding="utf-8")
        clean, redactions = _redact(raw)
        suffix = path.suffix.lower()

        if suffix == ".py":
            payload, extras = _extract_from_python(raw)
        elif suffix == ".json":
            data = json.loads(raw)
            if not isinstance(data, dict):
                return ToolResult(
                    ok=False, error="JSON config must be an object at the top level."
                )
            payload, extras = _extract_from_dict(data)
        elif suffix in (".yaml", ".yml"):
            data = yaml.safe_load(raw)
            if not isinstance(data, dict):
                return ToolResult(
                    ok=False, error="YAML config must be a mapping at the top level."
                )
            payload, extras = _extract_from_dict(data)
        else:
            return ToolResult(
                ok=False,
                error=f"Unsupported config extension '{suffix}'. Expected .py/.json/.yaml/.yml.",
            )

        # Scrub extracted values in place and merge any new labels,
        # keeping first-seen order without duplicates.
        for bucket in (extras, payload):
            for label in _redact_extras(bucket):
                if label not in redactions:
                    redactions.append(label)

        return _build_config(payload, extras, clean, redactions)
    except SyntaxError as exc:
        return ToolResult(ok=False, error=f"Python parse error: {exc}")
    except json.JSONDecodeError as exc:
        return ToolResult(ok=False, error=f"JSON parse error: {exc}")
    except yaml.YAMLError as exc:
        return ToolResult(ok=False, error=f"YAML parse error: {exc}")
    except Exception as exc:
        # Tool boundary: report any other failure as a tool error.
        return ToolResult(ok=False, error=f"{type(exc).__name__}: {exc}")
|
|
|
|
def _parse_config(file_path: str) -> ToolResult:
    """Tool entry point: parse a training-config artefact into a result dict.

    Delegates to ``_parse_config_full``, which picks the parser from the
    path suffix (``.py`` via the AST walk, ``.json`` / ``.yaml`` / ``.yml``
    via the dict path) and rejects anything else instead of guessing.

    On success the dumped config **omits** ``raw_source`` so the LLM's audit
    conversation stays small enough for the model's context window; the
    redacted source remains available on the ``WorkloadConfig`` returned by
    ``_parse_config_full``. Failures are passed through unchanged.
    """
    outcome = _parse_config_full(file_path)
    if not isinstance(outcome, ToolResult):
        trimmed = outcome.model_dump(exclude={"raw_source"})
        return ToolResult(ok=True, result=trimmed)
    return outcome
|
|
|
|
# Tool registration: exposes ``_parse_config`` under the name "parse_config"
# with a JSON-schema description of its single ``file_path`` argument.
PARSE_CONFIG = Tool(
    name="parse_config",
    fn=_parse_config,
    description=(
        "Parse a user-uploaded training script or HF TrainingArguments JSON/YAML "
        "into a normalized WorkloadConfig. Redacts tokens, keys, and filesystem "
        "paths before returning."
    ),
    input_schema={
        "type": "object",
        "properties": {
            "file_path": {
                "type": "string",
                "description": "Absolute path to the uploaded training script or config file.",
            }
        },
        "required": ["file_path"],
    },
)
|
|