File size: 4,926 Bytes
8971251
 
1cbfdb8
f62552c
 
 
1cbfdb8
f62552c
1cbfdb8
f62552c
c0f3a90
f62552c
 
 
 
 
1cbfdb8
 
 
 
 
c0f3a90
1cbfdb8
 
 
c0f3a90
 
 
 
 
 
1cbfdb8
 
c0f3a90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cbfdb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a0caf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cbfdb8
 
 
 
 
3a0caf3
 
1cbfdb8
 
 
 
8971251
934dccb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/bin/bash
#
# Startup script for the Argilla server:
#   1. apply database migrations,
#   2. create the admin user described by USERNAME / PASSWORD,
#   3. remove the default accounts and set ARGILLA_API_KEY (embedded Python),
#   4. exec the Argilla server on port 6900.

# Initialize database (run migrations - this also creates default users).
# Abort on failure: every later step assumes a migrated schema.
echo "Running database migrations..."
if ! python -m argilla_server database migrate; then
    echo "ERROR: database migration failed, aborting startup" >&2
    exit 1
fi

# Create custom admin user from env vars.
# NOTE: the password is passed as a command-line argument, so it is visible to
# other processes in the container (e.g. via ps) while this command runs.
# Any failure (not only "already exists") is reported and then ignored.
if [ -n "$USERNAME" ] && [ -n "$PASSWORD" ]; then
    echo "Creating admin user: $USERNAME"
    python -m argilla_server database users create \
        --first-name "$USERNAME" \
        --username "$USERNAME" \
        --password "$PASSWORD" \
        --role owner || echo "User may already exist, continuing..."
elif [ -n "$USERNAME" ] || [ -n "$PASSWORD" ]; then
    # Only one of the pair was provided: previously this was skipped silently.
    echo "WARNING: both USERNAME and PASSWORD must be set to create the admin user; skipping" >&2
fi

# Security hardening: delete default users and set custom API key.
# The embedded Python edits Argilla's SQLite database directly. If it exits
# non-zero (any unexpected error), startup is aborted so the server never comes
# up claiming to be hardened while the default accounts may still exist.
# "Database not found" is a deliberate, non-fatal skip inside the script.
echo "Securing default user accounts..."
python3 << 'PYEOF' || { echo "ERROR: security hardening failed; refusing to start Argilla" >&2; exit 1; }
"""Harden a freshly migrated Argilla SQLite database.

Deletes the default accounts, sets ARGILLA_API_KEY on the user named by
USERNAME, and makes sure that user is a member of the 'argilla' workspace.
All changes are committed together; on any error nothing is committed and
the script exits non-zero.
"""
import os
import sqlite3
import subprocess
import uuid
from contextlib import closing
from datetime import datetime, timezone

DEFAULT_USERS = ("owner", "admin", "argilla")


def find_database():
    """Return the path to the Argilla SQLite database file, or None.

    Lookup order: ARGILLA_DATABASE_URL, ARGILLA_HOME_PATH, a few well-known
    locations, then a filesystem-wide search as a last resort.
    """
    db_url = os.environ.get("ARGILLA_DATABASE_URL", "")
    if db_url and not db_url.startswith("sqlite"):
        # A non-SQLite backend is configured, so any argilla.db on disk is not
        # the live database; searching for one could modify the wrong file.
        # Only the scheme is printed to avoid leaking credentials.
        print(f"WARNING: ARGILLA_DATABASE_URL uses '{db_url.split(':', 1)[0]}', not SQLite")
        return None

    candidates = []
    if "///" in db_url:
        # e.g. sqlite+aiosqlite:////data/argilla.db -> /data/argilla.db;
        # drop any "?option=..." suffix that follows the path.
        candidates.append(db_url.split("///", 1)[1].split("?", 1)[0])
    home_path = os.environ.get("ARGILLA_HOME_PATH")
    if home_path:
        candidates.append(os.path.join(home_path, "argilla.db"))
    candidates += [
        "~/.argilla/argilla.db",
        "/home/argilla/.argilla/argilla.db",
        "/root/.argilla/argilla.db",
        "/tmp/argilla.db",
    ]
    for candidate in candidates:
        path = os.path.expanduser(candidate)
        if os.path.exists(path):
            return path

    # Broad search as last resort; /proc and /sys are pruned because they
    # are large, virtual, and cannot contain the database.
    try:
        result = subprocess.run(
            ["find", "/", "(", "-path", "/proc", "-o", "-path", "/sys", ")", "-prune",
             "-o", "-name", "argilla.db", "-type", "f", "-print"],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, subprocess.SubprocessError) as e:
        print(f"Find search failed: {e}")
        return None
    return next((line.strip() for line in result.stdout.splitlines() if line.strip()), None)


def sqlite_timestamp():
    """Return the current UTC time as naive 'YYYY-MM-DD HH:MM:SS.ffffff' text."""
    # Same value as the deprecated datetime.utcnow(), written in the text
    # layout SQLAlchemy uses for DateTime columns on SQLite (space separator,
    # always 6 fractional digits) so stored values parse and sort consistently.
    return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")


def new_id(like):
    """Return a fresh UUID string formatted like the existing id *like*."""
    # SQLAlchemy's Uuid type stores UUIDs as 32-char hex (no dashes) on SQLite.
    # Mirror the format of ids already in the database so Argilla's own
    # lookups match rows created here; default to hex when there is no sample.
    generated = uuid.uuid4()
    if isinstance(like, str) and "-" in like:
        return str(generated)
    return generated.hex


def harden(conn, username, api_key):
    """Apply all hardening changes through *conn* without committing."""
    cursor = conn.cursor()

    # Show current users
    cursor.execute("SELECT username, role FROM users")
    print(f"Users before hardening: {cursor.fetchall()}")

    if username in DEFAULT_USERS:
        print(f"WARNING: USERNAME '{username}' is a default account name; it is kept "
              "and its password is not modified by this script")

    # Delete default users (owner, admin, argilla) but keep our custom user.
    for default_user in DEFAULT_USERS:
        if default_user == username:
            continue
        # sqlite3 connections do not enforce foreign keys unless PRAGMA
        # foreign_keys is enabled, so schema ON DELETE actions would not fire;
        # remove the user's workspace memberships explicitly.
        cursor.execute(
            "DELETE FROM workspaces_users WHERE user_id IN "
            "(SELECT id FROM users WHERE username = ?)",
            (default_user,),
        )
        cursor.execute("DELETE FROM users WHERE username = ?", (default_user,))
        if cursor.rowcount > 0:
            print(f"Deleted default user: {default_user}")

    # Set custom API key on our admin user so backup script can authenticate.
    if api_key and username:
        cursor.execute("UPDATE users SET api_key = ? WHERE username = ?", (api_key, username))
        if cursor.rowcount > 0:
            print(f"Updated API key for user: {username}")
        else:
            print(f"WARNING: user '{username}' not found; API key not set")

    if not username:
        return

    # Create the 'argilla' workspace if needed and assign our user to it.
    cursor.execute("SELECT id FROM users WHERE username = ?", (username,))
    user_row = cursor.fetchone()
    user_id = user_row[0] if user_row else None
    now = sqlite_timestamp()

    cursor.execute("SELECT id FROM workspaces WHERE name = 'argilla'")
    ws_row = cursor.fetchone()
    if ws_row:
        ws_id = ws_row[0]
        print("Workspace 'argilla' already exists")
    else:
        ws_id = new_id(like=user_id)
        cursor.execute(
            "INSERT INTO workspaces (id, name, inserted_at, updated_at) VALUES (?, 'argilla', ?, ?)",
            (ws_id, now, now),
        )
        print("Created workspace 'argilla'")

    if user_id is None:
        print(f"WARNING: user '{username}' not found; not assigning it to workspace 'argilla'")
        return

    cursor.execute(
        "SELECT id FROM workspaces_users WHERE workspace_id = ? AND user_id = ?",
        (ws_id, user_id),
    )
    if cursor.fetchone():
        print(f"'{username}' already in workspace 'argilla'")
        return
    cursor.execute(
        "INSERT INTO workspaces_users (id, workspace_id, user_id, inserted_at, updated_at) "
        "VALUES (?, ?, ?, ?, ?)",
        (new_id(like=user_id), ws_id, user_id, now, now),
    )
    print(f"Assigned '{username}' to workspace 'argilla'")


db_path = find_database()
if not db_path:
    print("WARNING: Could not find Argilla database, skipping security hardening")
else:
    print(f"Found database at: {db_path}")
    # closing() guarantees the connection is closed even on error; an
    # uncommitted transaction is discarded, so hardening is all-or-nothing.
    with closing(sqlite3.connect(db_path)) as conn:
        harden(conn, os.environ.get("USERNAME", ""), os.environ.get("ARGILLA_API_KEY", ""))
        conn.commit()

        # Verify
        cursor = conn.cursor()
        cursor.execute("SELECT username, role FROM users")
        print(f"Users after hardening: {cursor.fetchall()}")
        cursor.execute("SELECT name FROM workspaces")
        print(f"Workspaces: {cursor.fetchall()}")
PYEOF

echo "Security hardening complete."

# Launch Argilla: `exec` replaces this shell with the server process, so
# signals delivered to the script go straight to the server. It listens on
# every interface (0.0.0.0) at port 6900.
exec python -m argilla_server start \
    --host 0.0.0.0 \
    --port 6900