File size: 10,123 Bytes
59edb07
 
 
 
c72bb17
59edb07
 
 
c72bb17
59edb07
 
 
 
 
 
 
c72bb17
 
 
 
10292a0
 
59edb07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c72bb17
 
 
 
 
 
 
 
 
 
0e4c818
 
 
 
 
59edb07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32636fa
 
 
 
59edb07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c72bb17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e4c818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""Database β€” SQLite persistence for simulation state."""

from __future__ import annotations

import hashlib
import json
import logging
import os
import secrets
from pathlib import Path
from typing import Optional

import aiosqlite

# Module-level logger named after this module; used for connection and
# corrupt-snapshot messages emitted by the Database class.
logger = logging.getLogger(__name__)


def _hash_password(password: str, salt: str) -> str:
    return hashlib.sha256(f"{salt}{password}".encode()).hexdigest()

# SOCI_DATA_DIR env var lets you point at a persistent disk (e.g. /var/data on Render).
# When unset, falls back to the relative path "data" (resolved against the
# process's current working directory).
DB_DIR = Path(os.environ.get("SOCI_DATA_DIR", "data"))
# Default database file used when Database() is constructed without a path.
DEFAULT_DB = DB_DIR / "soci.db"

# DDL executed by Database.connect() on every connection. All statements use
# IF NOT EXISTS, so re-running the script against an existing file is safe.
# Tables: snapshots (JSON-serialized simulation state), event_log, conversations,
# users (credentials plus a single current session token per user), and
# settings (string key/value pairs).
SCHEMA = """
CREATE TABLE IF NOT EXISTS snapshots (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    tick INTEGER NOT NULL,
    day INTEGER NOT NULL,
    state_json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS event_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    tick INTEGER NOT NULL,
    day INTEGER NOT NULL,
    time_str TEXT NOT NULL,
    event_type TEXT NOT NULL,
    agent_id TEXT,
    location TEXT,
    description TEXT NOT NULL,
    metadata_json TEXT
);

CREATE TABLE IF NOT EXISTS conversations (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    conv_id TEXT NOT NULL,
    tick INTEGER NOT NULL,
    day INTEGER NOT NULL,
    location TEXT NOT NULL,
    participants_json TEXT NOT NULL,
    topic TEXT,
    turns_json TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_event_tick ON event_log(tick);
CREATE INDEX IF NOT EXISTS idx_event_agent ON event_log(agent_id);
CREATE INDEX IF NOT EXISTS idx_conv_tick ON conversations(tick);

CREATE TABLE IF NOT EXISTS users (
    id            INTEGER PRIMARY KEY AUTOINCREMENT,
    username      TEXT    NOT NULL UNIQUE,
    password_hash TEXT    NOT NULL,
    salt          TEXT    NOT NULL,
    token         TEXT    UNIQUE,
    agent_id      TEXT,
    created_at    TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS settings (
    key   TEXT PRIMARY KEY,
    value TEXT NOT NULL
);
"""


class Database:
    """Async SQLite database for simulation persistence.

    Wraps a single ``aiosqlite`` connection. Call :meth:`connect` before using
    any other method and :meth:`close` when finished. Every write method
    commits immediately.

    Data methods raise ``RuntimeError`` when called while not connected.
    """

    def __init__(self, db_path: str | Path = DEFAULT_DB) -> None:
        """Remember where the database file lives; no I/O until connect()."""
        self.db_path = Path(db_path)
        self._db: Optional[aiosqlite.Connection] = None

    def _conn(self) -> aiosqlite.Connection:
        """Return the open connection, or raise if connect() has not run.

        An explicit check rather than ``assert``: asserts are stripped under
        ``python -O``, which would turn misuse into an opaque AttributeError
        on ``None``.
        """
        if self._db is None:
            raise RuntimeError("Database is not connected; call connect() first")
        return self._db

    async def connect(self) -> None:
        """Connect to the database and create tables.

        Creates the parent directory of ``db_path`` if needed, then runs
        SCHEMA (all ``CREATE ... IF NOT EXISTS``, so safe to repeat). Any
        previously open connection is closed first so it is not leaked.
        """
        await self.close()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._db = await aiosqlite.connect(str(self.db_path))
        await self._db.executescript(SCHEMA)
        await self._db.commit()
        logger.info("Database connected: %s", self.db_path)

    async def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly.

        ``_db`` is cleared before awaiting the close so that later calls see
        the database as disconnected instead of reusing a closed connection.
        """
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    async def save_snapshot(self, name: str, tick: int, day: int, state: dict) -> int:
        """Save a full simulation state snapshot.

        ``state`` must be JSON-serializable. Returns the new row's id.
        """
        db = self._conn()
        cursor = await db.execute(
            "INSERT INTO snapshots (name, tick, day, state_json) VALUES (?, ?, ?, ?)",
            (name, tick, day, json.dumps(state)),
        )
        await db.commit()
        return cursor.lastrowid

    async def load_snapshot(self, name: Optional[str] = None) -> Optional[dict]:
        """Load the latest snapshot, or the latest one with a specific name.

        A falsy ``name`` (None or "") selects the newest snapshot overall.
        Returns None when nothing matches or the stored JSON is corrupt
        (corruption is logged as a warning, not raised).
        """
        db = self._conn()
        if name:
            cursor = await db.execute(
                "SELECT state_json FROM snapshots WHERE name = ? ORDER BY id DESC LIMIT 1",
                (name,),
            )
        else:
            cursor = await db.execute(
                "SELECT state_json FROM snapshots ORDER BY id DESC LIMIT 1",
            )
        row = await cursor.fetchone()
        if row is None:
            return None
        try:
            return json.loads(row[0])
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            logger.warning("Corrupt snapshot in DB, ignoring: %s", e)
            return None

    async def list_snapshots(self) -> list[dict]:
        """List all saved snapshots (metadata only), newest first."""
        db = self._conn()
        cursor = await db.execute(
            "SELECT id, name, created_at, tick, day FROM snapshots ORDER BY id DESC"
        )
        rows = await cursor.fetchall()
        return [
            {"id": r[0], "name": r[1], "created_at": r[2], "tick": r[3], "day": r[4]}
            for r in rows
        ]

    async def log_event(
        self,
        tick: int,
        day: int,
        time_str: str,
        event_type: str,
        description: str,
        agent_id: str = "",
        location: str = "",
        metadata: Optional[dict] = None,
    ) -> None:
        """Log a simulation event.

        ``metadata`` is stored as JSON; a falsy value (None or an empty dict)
        is stored as NULL.
        """
        db = self._conn()
        await db.execute(
            "INSERT INTO event_log (tick, day, time_str, event_type, agent_id, location, description, metadata_json) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (tick, day, time_str, event_type, agent_id, location, description,
             json.dumps(metadata) if metadata else None),
        )
        await db.commit()

    async def save_conversation(self, conv_data: dict) -> None:
        """Save a completed conversation.

        ``conv_data`` must contain ``id``, ``location``, ``participants`` and
        ``turns`` (a list whose entries carry a ``tick``); ``topic`` and
        ``day`` are optional. The stored tick is the last turn's tick, or 0
        when there are no turns.
        """
        db = self._conn()
        turns = conv_data["turns"]
        await db.execute(
            "INSERT INTO conversations (conv_id, tick, day, location, participants_json, topic, turns_json) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                conv_data["id"],
                turns[-1]["tick"] if turns else 0,
                # Callers that know the clock's day may supply it; callers that
                # omit it keep the previous placeholder of 0.
                conv_data.get("day", 0),
                conv_data["location"],
                json.dumps(conv_data["participants"]),
                conv_data.get("topic", ""),
                json.dumps(turns),
            ),
        )
        await db.commit()

    async def get_recent_events(self, limit: int = 50) -> list[dict]:
        """Get up to ``limit`` most recent events from the log, newest first."""
        db = self._conn()
        cursor = await db.execute(
            "SELECT tick, day, time_str, event_type, agent_id, location, description "
            "FROM event_log ORDER BY id DESC LIMIT ?",
            (limit,),
        )
        rows = await cursor.fetchall()
        return [
            {
                "tick": r[0], "day": r[1], "time_str": r[2],
                "event_type": r[3], "agent_id": r[4],
                "location": r[5], "description": r[6],
            }
            for r in rows
        ]

    # ── Auth / user methods ──────────────────────────────────────────────────

    async def create_user(self, username: str, password: str) -> dict:
        """Create a new user. Raises ValueError if username taken.

        Returns a dict with ``username``, a freshly issued session ``token``
        and ``agent_id`` (None until set_user_agent() links one).
        """
        db = self._conn()
        salt = secrets.token_hex(16)
        pw_hash = _hash_password(password, salt)
        token = secrets.token_hex(32)
        try:
            await db.execute(
                "INSERT INTO users (username, password_hash, salt, token) VALUES (?, ?, ?, ?)",
                (username, pw_hash, salt, token),
            )
            await db.commit()
        except aiosqlite.IntegrityError as err:
            # Chain the SQLite error so the violated constraint stays visible.
            raise ValueError(f"Username '{username}' is already taken") from err
        return {"username": username, "token": token, "agent_id": None}

    async def authenticate_user(self, username: str, password: str) -> Optional[dict]:
        """Verify credentials and return user dict with fresh token, or None.

        A successful login overwrites the user's single stored token, so any
        token issued earlier for that user stops resolving.
        """
        db = self._conn()
        cursor = await db.execute(
            "SELECT username, password_hash, salt, agent_id FROM users WHERE username = ?",
            (username,),
        )
        row = await cursor.fetchone()
        if row is None:
            return None
        stored_username, stored_hash, salt, agent_id = row
        candidate = _hash_password(password, salt)
        # Constant-time comparison so response timing does not reveal how
        # much of the hash matched. Encoded to bytes so non-ASCII data cannot
        # raise TypeError.
        if not secrets.compare_digest(candidate.encode(), stored_hash.encode()):
            return None
        token = secrets.token_hex(32)
        await db.execute(
            "UPDATE users SET token = ? WHERE username = ?", (token, username)
        )
        await db.commit()
        return {"username": stored_username, "token": token, "agent_id": agent_id}

    async def get_user_by_token(self, token: str) -> Optional[dict]:
        """Look up a user by their session token; None if no user holds it."""
        db = self._conn()
        cursor = await db.execute(
            "SELECT username, agent_id FROM users WHERE token = ?", (token,)
        )
        row = await cursor.fetchone()
        if row is None:
            return None
        return {"username": row[0], "agent_id": row[1]}

    async def set_user_agent(self, username: str, agent_id: str) -> None:
        """Link a player agent to a user account (no-op for unknown users)."""
        db = self._conn()
        await db.execute(
            "UPDATE users SET agent_id = ? WHERE username = ?", (agent_id, username)
        )
        await db.commit()

    async def logout_user(self, token: str) -> None:
        """Invalidate a session token by clearing it on the owning user row."""
        db = self._conn()
        await db.execute("UPDATE users SET token = NULL WHERE token = ?", (token,))
        await db.commit()

    # ── Settings / persistent config ─────────────────────────────────────────

    async def get_setting(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Read a persisted setting by key, returning ``default`` if absent."""
        db = self._conn()
        cursor = await db.execute("SELECT value FROM settings WHERE key = ?", (key,))
        row = await cursor.fetchone()
        return row[0] if row else default

    async def set_setting(self, key: str, value: str) -> None:
        """Upsert a persisted setting.

        Uses SQLite's ``ON CONFLICT ... DO UPDATE`` upsert (SQLite 3.24+).
        """
        db = self._conn()
        await db.execute(
            "INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            (key, value),
        )
        await db.commit()