Spaces:
Running
Running
| # server/demo_agent.py | |
| # Simple rule-based demo agent for the Gradio UI. | |
| # Uses hardcoded heuristics to show the environment works without calling a real LLM. | |
def demo_action(obs):
    """Pick a demo action for the UI from a single observation dict.

    Routes on ``obs['task_type']`` to the matching rule-based helper,
    forwarding ``task_id`` (default ``''``) and ``turn`` (default ``0``).
    Unknown task types yield ``{'action_type': 'invalid'}``.
    """
    kind = obs.get('task_type', '')
    tid = obs.get('task_id', '')
    step = obs.get('turn', 0)

    if kind == 'security':
        return _security_action(obs, tid, step)
    if kind == 'dependency':
        return _dependency_action(obs, tid, step)
    if kind == 'clinical':
        return _clinical_action(obs, tid, step)

    # No handler for this task type.
    return {'action_type': 'invalid'}
| def _security_action(obs, task_id, turn): | |
| if turn == 0: | |
| tool_call = obs.get('tool_call', '') | |
| # Simple heuristic to detect common vulnerability types | |
| vuln_type = 'sql_injection' | |
| severity = 'critical' | |
| cvss = 8.5 | |
| if 'script' in tool_call.lower() or 'xss' in tool_call.lower(): | |
| vuln_type = 'xss' | |
| severity = 'medium' | |
| cvss = 5.0 | |
| elif 'password' in tool_call.lower() or 'secret' in tool_call.lower(): | |
| vuln_type = 'hardcoded_secret' | |
| severity = 'high' | |
| cvss = 6.5 | |
| elif 'jwt' in tool_call.lower() or 'token' in tool_call.lower(): | |
| vuln_type = 'jwt_misuse' | |
| severity = 'critical' | |
| cvss = 8.0 | |
| elif 'path' in tool_call.lower() or '..' in tool_call: | |
| vuln_type = 'path_traversal' | |
| severity = 'high' | |
| cvss = 7.0 | |
| elif 'auth' in tool_call.lower() and 'no' in tool_call.lower(): | |
| vuln_type = 'missing_auth' | |
| severity = 'critical' | |
| cvss = 8.5 | |
| return { | |
| 'action_type': 'identify_vulnerability', | |
| 'vuln_type': vuln_type, | |
| 'cvss_score': cvss, | |
| 'severity': severity, | |
| 'affected_line': 1, | |
| } | |
| elif 'reviewer_feedback' in obs: | |
| return { | |
| 'action_type': 'revise_fix', | |
| 'fix_code': 'sanitize_input(parameterized_query)', | |
| 'addressed_feedback': obs.get('reviewer_feedback', 'fixed the issue'), | |
| } | |
| else: | |
| return { | |
| 'action_type': 'propose_fix', | |
| 'fix_code': 'use parameterized query with ? placeholder', | |
| 'explanation': 'Replace string concatenation with parameterized queries', | |
| } | |
| def _dependency_action(obs, task_id, turn): | |
| task_subtype = obs.get('task_subtype', 'flag') | |
| if task_subtype == 'flag': | |
| return { | |
| 'action_type': 'flag_outdated', | |
| 'packages': {'torch': '1.9.0'}, | |
| 'deprecated_api': 'torch.autograd.Variable', | |
| 'replacement': 'plain tensor', | |
| } | |
| elif task_subtype == 'resolve': | |
| return { | |
| 'action_type': 'resolve_conflict', | |
| 'packages': {'torch': '2.1.0', 'numpy': '1.24.0'}, | |
| 'reasoning': 'PyTorch 2.1 requires NumPy 1.24+', | |
| } | |
| else: # migrate | |
| return { | |
| 'action_type': 'migrate_api', | |
| 'completed_items': ['break_001', 'break_002'], | |
| 'code_changes': { | |
| 'break_001': 'torch.where(condition, x*2, x)', | |
| 'break_002': 'x.shape[0]', | |
| }, | |
| } | |
| def _clinical_action(obs, task_id, turn): | |
| available_steps = obs.get('available_steps', []) | |
| if turn == 0: | |
| return { | |
| 'action_type': 'detect_gap', | |
| 'missing_steps': available_steps[:2] if available_steps else ['unknown_step'], | |
| 'risk_level': 'critical', | |
| } | |
| elif turn == 1: | |
| return { | |
| 'action_type': 'rank_issues', | |
| 'priority_order': available_steps[:3] if available_steps else ['unknown_step'], | |
| } | |
| else: | |
| dep_graph = obs.get('dependency_graph', {}) | |
| # Simple topological sort attempt | |
| ordered = _simple_topo_sort(available_steps, dep_graph) | |
| return { | |
| 'action_type': 'order_steps', | |
| 'recovery_steps': ordered, | |
| } | |
| def _simple_topo_sort(steps, dep_graph): | |
| """Simple topological sort for dependency ordering.""" | |
| if not dep_graph: | |
| return steps | |
| result = [] | |
| remaining = set(steps) | |
| for _ in range(len(steps) + 1): | |
| if not remaining: | |
| break | |
| for step in list(remaining): | |
| prereqs = dep_graph.get(step, []) | |
| if all(p in result for p in prereqs): | |
| result.append(step) | |
| remaining.remove(step) | |
| break | |
| # Add any unresolved steps | |
| result.extend(remaining) | |
| return result | |