Spaces:
Running
Running
File size: 9,132 Bytes
"""
In-memory SQLite database for the buggy API.

Supports resetting between episodes with domain randomization: each seed
produces different users, tasks, and data distributions, so every training
episode is unique. When no seed is given, a fixed default dataset is loaded
instead (useful for manual testing).
"""
import random
import sqlite3
import threading
from contextlib import contextmanager
# Name pools for randomized seed data.
# Seeded episodes draw their users, email domain, and tasks from these lists.
# Editing, reordering, or resizing them changes which dataset a given seed
# produces.

# Candidate usernames (24 distinct lowercase first names).
FIRST_NAMES = (
    "alice bob charlie diana ethan fiona george hannah "
    "ivan julia kevin luna mike nina oscar priya "
    "quinn ravi sara tom uma victor wendy xander"
).split()

# One domain is chosen per episode and shared by every user's email address.
DOMAINS = "example.com company.org startup.io work.dev test.net".split()

# Task titles; a seeded episode samples a handful of these without replacement.
TASK_TITLES = [
    "Setup CI/CD pipeline",
    "Write unit tests",
    "Fix login page CSS",
    "Database migration",
    "API documentation",
    "Refactor auth module",
    "Add rate limiting",
    "Setup monitoring",
    "Fix memory leak",
    "Update dependencies",
    "Add logging middleware",
    "Create admin panel",
    "Implement caching",
    "Fix CORS issues",
    "Add input validation",
    "Setup Docker compose",
    "Write integration tests",
    "Fix date parsing bug",
    "Add search functionality",
    "Implement pagination",
    "Setup SSL certs",
    "Add webhook support",
    "Fix timezone handling",
    "Create backup script",
    "Optimize database queries",
    "Add email notifications",
    "Fix file upload",
    "Implement user roles",
    "Add audit logging",
    "Setup load balancer",
]

# Task descriptions, sampled independently of the titles.
TASK_DESCRIPTIONS = [
    "Configure GitHub Actions for automated deployment",
    "Add tests for the auth module endpoints",
    "Button alignment issue on mobile devices",
    "Migrate from SQLite to PostgreSQL",
    "Document all REST endpoints with examples",
    "Break down the monolithic auth into smaller services",
    "Prevent API abuse with request throttling",
    "Setup Grafana dashboards for key metrics",
    "Memory usage grows unbounded after 1000 requests",
    "Several packages have critical CVEs",
    "Add structured JSON logging to all routes",
    "Build an admin dashboard for user management",
    "Add Redis caching layer for frequent queries",
    "Frontend gets blocked by CORS policy",
    "Sanitize user inputs to prevent injection",
]

# Allowed values for the tasks.status and tasks.priority columns.
STATUSES = "pending in_progress done".split()
PRIORITIES = "low medium high".split()
class Database:
"""Thread-safe in-memory SQLite database that can be reset between episodes.
When a seed is provided, the database is populated with deterministically
randomized data — different users, tasks, and distributions each time.
This prevents the agent from memorizing a single fixed dataset.
"""
def __init__(self, seed: int | None = None):
self._lock = threading.Lock()
self._conn: sqlite3.Connection | None = None
self._seed = seed
self.initialize()
def initialize(self):
"""Create a fresh database with schema and seed data."""
with self._lock:
if self._conn:
self._conn.close()
self._conn = sqlite3.connect(":memory:", check_same_thread=False)
self._conn.row_factory = sqlite3.Row
self._conn.execute("PRAGMA journal_mode=WAL")
self._create_schema()
self._seed_data()
def _create_schema(self):
cursor = self._conn.cursor()
cursor.executescript("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
email TEXT NOT NULL,
password_hash TEXT NOT NULL,
role TEXT DEFAULT 'user',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
description TEXT DEFAULT '',
status TEXT DEFAULT 'pending',
priority TEXT DEFAULT 'medium',
assignee_email TEXT DEFAULT '',
owner_id INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (owner_id) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS auth_tokens (
token TEXT PRIMARY KEY,
user_id INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users(id)
);
""")
self._conn.commit()
def _seed_data(self):
"""Seed the database with randomized data based on the seed.
With seed=None, uses a fixed default dataset (for manual testing).
With a seed, generates random users/tasks so every episode differs.
"""
rng = random.Random(self._seed)
cursor = self._conn.cursor()
if self._seed is None:
# Default fixed data for manual testing / Gradio UI
cursor.executescript("""
INSERT INTO users (username, email, password_hash, role) VALUES
('alice', 'alice@example.com', 'hashed_password123', 'admin'),
('bob', 'bob@example.com', 'hashed_password123', 'user'),
('charlie', 'charlie@example.com', 'hashed_password123', 'user');
INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES
('Setup CI/CD pipeline', 'Configure GitHub Actions', 'in_progress', 'high', 'alice@example.com', 1),
('Write unit tests', 'Add tests for auth module', 'pending', 'medium', 'bob@example.com', 2),
('Fix login page CSS', 'Button alignment issue', 'done', 'low', 'charlie@example.com', 3),
('Database migration', 'Migrate to PostgreSQL', 'pending', 'high', 'alice@example.com', 1),
('API documentation', 'Document all endpoints', 'in_progress', 'medium', 'bob@example.com', 2);
""")
else:
# Randomized data — different every episode
# Pick 3-5 users from the name pool
num_users = rng.randint(3, 5)
user_names = rng.sample(FIRST_NAMES, num_users)
domain = rng.choice(DOMAINS)
# First user is always admin, rest are regular users
for i, name in enumerate(user_names):
role = "admin" if i == 0 else "user"
email = f"{name}@{domain}"
cursor.execute(
"INSERT INTO users (username, email, password_hash, role) VALUES (?, ?, ?, ?)",
(name, email, f"hashed_password_{rng.randint(100, 999)}", role),
)
# Pick 4-8 tasks with random assignments
num_tasks = rng.randint(4, 8)
task_titles = rng.sample(TASK_TITLES, min(num_tasks, len(TASK_TITLES)))
task_descs = rng.sample(TASK_DESCRIPTIONS, min(num_tasks, len(TASK_DESCRIPTIONS)))
for i in range(num_tasks):
owner_id = rng.randint(1, num_users)
assignee_id = rng.randint(1, num_users)
assignee_email = f"{user_names[assignee_id - 1]}@{domain}"
cursor.execute(
"INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES (?, ?, ?, ?, ?, ?)",
(
task_titles[i % len(task_titles)],
task_descs[i % len(task_descs)] if i < len(task_descs) else "",
rng.choice(STATUSES),
rng.choice(PRIORITIES),
assignee_email,
owner_id,
),
)
self._conn.commit()
@property
def user_names(self) -> list[str]:
"""Get usernames in the database (for the agent's observation)."""
rows = self.execute("SELECT username FROM users ORDER BY id")
return [r["username"] for r in rows]
@contextmanager
def get_cursor(self):
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def execute(self, query: str, params: tuple = ()) -> list[dict]:
with self.get_cursor() as cursor:
cursor.execute(query, params)
if cursor.description:
columns = [col[0] for col in cursor.description]
return [dict(zip(columns, row)) for row in cursor.fetchall()]
return []
def execute_insert(self, query: str, params: tuple = ()) -> int:
with self.get_cursor() as cursor:
cursor.execute(query, params)
return cursor.lastrowid
def execute_update(self, query: str, params: tuple = ()) -> int:
with self.get_cursor() as cursor:
cursor.execute(query, params)
return cursor.rowcount
|