File size: 5,028 Bytes
47ccba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Database module for storing processed model architectures.
Uses SQLite for simple, file-based persistence.
"""

import json
import os
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from contextlib import contextmanager

# Location of the SQLite file. The DATABASE_PATH environment variable wins
# when it is set (e.g. inside Docker); otherwise the database lives in
# "models.db" two directories above this module.
DB_PATH = Path(
    os.getenv(
        "DATABASE_PATH",
        Path(__file__).parent.parent / "models.db",
    )
)


def get_connection():
    """Open a new SQLite connection to the file at ``DB_PATH``.

    Rows fetched through the returned connection are ``sqlite3.Row``
    objects, so columns can be read by name as well as by index. The
    caller is responsible for closing the connection.
    """
    connection = sqlite3.connect(str(DB_PATH))
    # sqlite3.Row lets callers do row['column'] and dict(row).
    connection.row_factory = sqlite3.Row
    return connection


@contextmanager
def get_db():
    """Yield a database connection wrapped in a single transaction.

    The transaction is committed once the ``with`` body finishes without
    raising. If the body (or the commit itself) raises an ``Exception``,
    the transaction is rolled back and the exception is re-raised. The
    connection is closed in every case.
    """
    connection = get_connection()
    try:
        yield connection
        # Reached only when the caller's block completed normally.
        connection.commit()
    except Exception:
        connection.rollback()
        raise
    finally:
        connection.close()


def init_db():
    """Create the ``saved_models`` table and its indexes if they are missing.

    Every statement uses ``IF NOT EXISTS``, so calling this repeatedly is
    harmless. All statements run on one connection obtained from
    ``get_db()``.
    """
    schema_statements = (
        # Main table: one row per stored model architecture.
        """
            CREATE TABLE IF NOT EXISTS saved_models (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                framework TEXT NOT NULL,
                total_parameters INTEGER DEFAULT 0,
                layer_count INTEGER DEFAULT 0,
                architecture_json TEXT NOT NULL,
                thumbnail TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                file_hash TEXT UNIQUE
            )
        """,
        # Index on the model name column.
        """
            CREATE INDEX IF NOT EXISTS idx_name ON saved_models(name)
        """,
        # Index on the creation timestamp column.
        """
            CREATE INDEX IF NOT EXISTS idx_created_at ON saved_models(created_at)
        """,
    )

    with get_db() as conn:
        for statement in schema_statements:
            conn.execute(statement)


def save_model(
    name: str,
    framework: str,
    total_parameters: int,
    layer_count: int,
    architecture: dict,
    file_hash: Optional[str] = None,
    thumbnail: Optional[str] = None
) -> int:
    """
    Save a model architecture to the database.

    When ``file_hash`` is given and a row with that hash already exists,
    that row is overwritten in place (and its ``created_at`` is reset to the
    current timestamp) instead of inserting a duplicate. Otherwise a new row
    is inserted.

    Args:
        name: Display name of the model.
        framework: Name of the framework the model comes from.
        total_parameters: Parameter count stored with the model.
        layer_count: Layer count stored with the model.
        architecture: Architecture description; stored as JSON text via
            ``json.dumps``.
        file_hash: Optional hash used to detect re-uploads of the same file.
            ``None`` and the empty string both mean "no hash".
        thumbnail: Optional thumbnail value stored as text.

    Returns:
        The ID of the updated or newly inserted row.

    Raises:
        TypeError, ValueError: Propagated from ``json.dumps`` when
            ``architecture`` cannot be serialized.
    """
    architecture_json = json.dumps(architecture)

    # Treat an empty hash as "no hash" and store it as NULL. The file_hash
    # column is UNIQUE: SQLite allows any number of NULLs there, but a second
    # row with '' would raise sqlite3.IntegrityError.
    file_hash = file_hash or None

    with get_db() as conn:
        if file_hash is not None:
            # Check if a model with the same hash already exists.
            existing = conn.execute(
                "SELECT id FROM saved_models WHERE file_hash = ?",
                (file_hash,)
            ).fetchone()

            if existing is not None:
                # Same file seen before: refresh the stored entry in place.
                conn.execute("""
                    UPDATE saved_models
                    SET name = ?, framework = ?, total_parameters = ?,
                        layer_count = ?, architecture_json = ?,
                        thumbnail = ?, created_at = CURRENT_TIMESTAMP
                    WHERE file_hash = ?
                """, (name, framework, total_parameters, layer_count,
                      architecture_json, thumbnail, file_hash))
                return existing['id']

        # No hash supplied, or no matching row: insert a new entry.
        cursor = conn.execute("""
            INSERT INTO saved_models
            (name, framework, total_parameters, layer_count, architecture_json, file_hash, thumbnail)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (name, framework, total_parameters, layer_count,
              architecture_json, file_hash, thumbnail))

        return cursor.lastrowid


def get_saved_models() -> List[dict]:
    """List every saved model, newest first, without the architecture payload.

    Each entry is a plain dict with the keys ``id``, ``name``, ``framework``,
    ``total_parameters``, ``layer_count``, ``thumbnail`` and ``created_at``.
    """
    with get_db() as conn:
        # architecture_json is deliberately not selected: metadata only.
        listing_query = """
            SELECT id, name, framework, total_parameters, layer_count, 
                   thumbnail, created_at
            FROM saved_models
            ORDER BY created_at DESC
        """
        records = conn.execute(listing_query).fetchall()

        return list(map(dict, records))


def get_model_by_id(model_id: int) -> Optional[dict]:
    """Fetch one saved model together with its decoded architecture.

    Returns a dict of the row's metadata plus an ``architecture`` key holding
    the parsed JSON, or ``None`` when no row has the given ID. The raw
    ``architecture_json`` column is not included in the result.
    """
    with get_db() as conn:
        cursor = conn.execute("""
            SELECT id, name, framework, total_parameters, layer_count,
                   architecture_json, thumbnail, created_at
            FROM saved_models
            WHERE id = ?
        """, (model_id,))
        record = cursor.fetchone()

        if record is None:
            return None

        model = dict(record)
        # Swap the stored JSON text for the decoded structure.
        model['architecture'] = json.loads(model.pop('architecture_json'))
        return model


def delete_model(model_id: int) -> bool:
    """Remove the saved model with the given ID.

    Returns True when a row was deleted and False when no row had that ID.
    """
    with get_db() as conn:
        removed_rows = conn.execute(
            "DELETE FROM saved_models WHERE id = ?",
            (model_id,)
        ).rowcount
        return removed_rows > 0


def model_exists_by_hash(file_hash: str) -> Optional[int]:
    """Look up a saved model by its file hash.

    Returns the ID of the matching row, or ``None`` when no row has that hash.
    """
    with get_db() as conn:
        match = conn.execute(
            "SELECT id FROM saved_models WHERE file_hash = ?",
            (file_hash,)
        ).fetchone()

        if match is None:
            return None
        return match['id']


# Initialize database on module load: importing this module creates the
# saved_models table and its indexes if they do not already exist.
init_db()