File size: 4,701 Bytes
054d73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5c2788
054d73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""SQLite persistence for encrypted OAuth tokens and analysis reports."""

from __future__ import annotations

import base64
import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import aiosqlite
from cryptography.fernet import Fernet


# SQLite database file "classlens.db", located in the directory two levels
# above this module (Path(__file__).parent.parent).
DATABASE_PATH = Path(__file__).parent.parent / "classlens.db"


# Fallback Fernet key used when ENCRYPTION_KEY is unset. It is generated at most
# once and then reused for the rest of the process lifetime.
_generated_encryption_key: Optional[bytes] = None


def get_encryption_key() -> bytes:
    """Return the Fernet key used to encrypt stored OAuth tokens.

    The key is read from the ``ENCRYPTION_KEY`` environment variable. When it
    is unset or empty, a temporary key is generated on the first call, a
    warning is printed, and that same key is returned on every later call.

    Returns:
        The key as bytes, suitable for passing to ``Fernet``.
    """
    global _generated_encryption_key

    key = os.getenv("ENCRYPTION_KEY", "")
    if key:
        return key.encode()

    if _generated_encryption_key is None:
        # Generate a new key for development (should be set in production).
        # It must be cached: encrypt_token() and decrypt_token() each call this
        # function, so a fresh key per call would make every stored token
        # undecryptable.
        _generated_encryption_key = Fernet.generate_key()
        print(
            "⚠️  No ENCRYPTION_KEY set. Generated temporary key: "
            f"{_generated_encryption_key.decode()}"
        )
        print("⚠️  Set this in your .env file for persistent token storage.")
    return _generated_encryption_key


def encrypt_token(token_data: dict) -> str:
    """Serialize ``token_data`` to JSON and encrypt it for storage.

    Returns:
        The Fernet token wrapped in an extra layer of standard base64, as a
        text string. decrypt_token() reverses exactly this encoding.
    """
    cipher = Fernet(get_encryption_key())
    plaintext = json.dumps(token_data).encode()
    fernet_token = cipher.encrypt(plaintext)
    # The outer base64 layer is part of the stored format that
    # decrypt_token() expects, so it is kept.
    return base64.b64encode(fernet_token).decode()


def decrypt_token(encrypted_data: str) -> dict:
    """Reverse encrypt_token(): unwrap base64, Fernet-decrypt, parse JSON.

    Args:
        encrypted_data: A string previously produced by encrypt_token().

    Returns:
        The original token payload, parsed from JSON.
    """
    cipher = Fernet(get_encryption_key())
    fernet_token = base64.b64decode(encrypted_data.encode())
    return json.loads(cipher.decrypt(fernet_token).decode())


async def init_database():
    """Create the ``oauth_tokens`` and ``analysis_reports`` tables if missing.

    Every statement uses ``CREATE TABLE IF NOT EXISTS``, so calling this on an
    already-initialized database leaves existing tables untouched.
    """
    schema_statements = (
        """
            CREATE TABLE IF NOT EXISTS oauth_tokens (
                teacher_email TEXT PRIMARY KEY,
                encrypted_tokens TEXT NOT NULL,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            )
        """,
        """
            CREATE TABLE IF NOT EXISTS analysis_reports (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                teacher_email TEXT NOT NULL,
                exam_title TEXT,
                report_markdown TEXT,
                report_json TEXT,
                created_at TEXT NOT NULL
            )
        """,
    )

    async with aiosqlite.connect(DATABASE_PATH) as conn:
        for statement in schema_statements:
            await conn.execute(statement)
        await conn.commit()


async def save_oauth_tokens(teacher_email: str, tokens: dict):
    """Encrypt and upsert the OAuth tokens for a teacher.

    Inserts a new ``oauth_tokens`` row keyed by ``teacher_email``. If a row
    already exists, only ``encrypted_tokens`` and ``updated_at`` are
    overwritten, so ``created_at`` keeps the time of the first save.

    Args:
        teacher_email: Primary key of the ``oauth_tokens`` row.
        tokens: JSON-serializable token payload; encrypted via encrypt_token()
            before it is written.
    """
    encrypted = encrypt_token(tokens)
    # Naive UTC timestamp in exactly the ISO-8601 format datetime.utcnow()
    # produced (no "+00:00" suffix), without the deprecated utcnow() call.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    async with aiosqlite.connect(DATABASE_PATH) as db:
        await db.execute("""
            INSERT INTO oauth_tokens (teacher_email, encrypted_tokens, created_at, updated_at)
            VALUES (?, ?, ?, ?)
            ON CONFLICT(teacher_email) DO UPDATE SET
                encrypted_tokens = excluded.encrypted_tokens,
                updated_at = excluded.updated_at
        """, (teacher_email, encrypted, now, now))
        await db.commit()


async def get_oauth_tokens(teacher_email: str) -> Optional[dict]:
    """Load and decrypt the stored OAuth tokens for a teacher.

    Returns:
        The decrypted token payload, or ``None`` when no row exists for
        ``teacher_email``. Errors raised while decrypting (for example if the
        encryption key differs from the one used when saving) are not caught
        here and propagate to the caller.
    """
    async with aiosqlite.connect(DATABASE_PATH) as conn:
        async with conn.execute(
            "SELECT encrypted_tokens FROM oauth_tokens WHERE teacher_email = ?",
            (teacher_email,),
        ) as cur:
            record = await cur.fetchone()

    if record is None:
        return None
    return decrypt_token(record[0])


async def delete_oauth_tokens(teacher_email: str):
    """Remove the stored OAuth tokens row for a teacher, if one exists.

    Deleting a teacher with no stored row is not an error; it simply affects
    zero rows.
    """
    delete_sql = "DELETE FROM oauth_tokens WHERE teacher_email = ?"
    async with aiosqlite.connect(DATABASE_PATH) as conn:
        await conn.execute(delete_sql, (teacher_email,))
        await conn.commit()


async def save_report(teacher_email: str, exam_title: str, report_markdown: str, report_json: str):
    """Insert a new analysis report for a teacher.

    Args:
        teacher_email: Owner of the report.
        exam_title: Title stored alongside the report.
        report_markdown: Markdown rendering of the report.
        report_json: Report data stored as-is in a TEXT column (presumably
            already-serialized JSON; confirm with callers).
    """
    # Naive UTC timestamp in exactly the ISO-8601 format datetime.utcnow()
    # produced. Keeping the format unchanged matters because
    # get_teacher_reports() orders by this text column.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    async with aiosqlite.connect(DATABASE_PATH) as db:
        await db.execute("""
            INSERT INTO analysis_reports (teacher_email, exam_title, report_markdown, report_json, created_at)
            VALUES (?, ?, ?, ?, ?)
        """, (teacher_email, exam_title, report_markdown, report_json, now))
        await db.commit()


async def get_teacher_reports(teacher_email: str, limit: int = 10) -> list[dict]:
    """Return up to ``limit`` of a teacher's reports, newest first.

    Each report is a dict with the keys ``id``, ``exam_title``,
    ``report_markdown``, ``report_json`` and ``created_at``. Rows are ordered by
    the ``created_at`` text column, descending.
    """
    query = """
            SELECT id, exam_title, report_markdown, report_json, created_at
            FROM analysis_reports
            WHERE teacher_email = ?
            ORDER BY created_at DESC
            LIMIT ?
        """
    async with aiosqlite.connect(DATABASE_PATH) as conn:
        # aiosqlite.Row lets each result be converted to a column-keyed dict.
        conn.row_factory = aiosqlite.Row
        async with conn.execute(query, (teacher_email, limit)) as cur:
            reports = list(map(dict, await cur.fetchall()))
    return reports