Spaces:
Runtime error
Runtime error
File size: 5,028 Bytes
47ccba4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
"""
Database module for storing processed model architectures.
Uses SQLite for simple, file-based persistence.
"""
import json
import os
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from contextlib import contextmanager
# Location of the SQLite file. The DATABASE_PATH environment variable wins
# when set (e.g. under Docker); otherwise fall back to "models.db" two
# directory levels above this module.
_DEFAULT_DB_FILE = Path(__file__).parent.parent / "models.db"
DB_PATH = Path(os.getenv("DATABASE_PATH", _DEFAULT_DB_FILE))
def get_connection():
    """Open a new SQLite connection to the file at ``DB_PATH``.

    Rows fetched through the returned connection are ``sqlite3.Row``
    objects, so columns can be read by name as well as by position.
    This function does not close the connection; that is left to the
    caller.
    """
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row
    return connection
@contextmanager
def get_db():
    """Yield a database connection with commit/rollback handling.

    The connection is committed when the ``with`` body finishes without
    raising. If the body (or the commit itself) raises an ``Exception``,
    the connection is rolled back and the exception is re-raised. The
    connection is closed in every case.
    """
    connection = get_connection()
    try:
        try:
            yield connection
            # Commit stays inside the guarded region so that a failing
            # commit is rolled back as well.
            connection.commit()
        except Exception:
            connection.rollback()
            raise
    finally:
        connection.close()
def init_db():
    """Create the ``saved_models`` table and its indexes if they are missing.

    Every statement uses ``IF NOT EXISTS``, so calling this function again
    on an already-initialized database leaves the existing schema in place.
    """
    # Executed in order: the table must exist before its indexes.
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS saved_models (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT NOT NULL,
        framework TEXT NOT NULL,
        total_parameters INTEGER DEFAULT 0,
        layer_count INTEGER DEFAULT 0,
        architecture_json TEXT NOT NULL,
        thumbnail TEXT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        file_hash TEXT UNIQUE
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_name ON saved_models(name)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_created_at ON saved_models(created_at)
        """,
    )
    with get_db() as conn:
        for statement in schema_statements:
            conn.execute(statement)
def save_model(
    name: str,
    framework: str,
    total_parameters: int,
    layer_count: int,
    architecture: dict,
    file_hash: Optional[str] = None,
    thumbnail: Optional[str] = None
) -> int:
    """
    Save a model architecture to the database.

    If ``file_hash`` is given and a row with that hash already exists, the
    row is updated in place: every column is overwritten, including
    ``thumbnail`` (even when it is ``None``), and ``created_at`` is reset to
    the current timestamp. Otherwise a new row is inserted.

    Args:
        name: Display name of the model.
        framework: Name of the framework the model comes from.
        total_parameters: Parameter count stored with the model.
        layer_count: Layer count stored with the model.
        architecture: Architecture description; serialized with
            ``json.dumps`` before being stored.
        file_hash: Optional hash used to detect a previously saved model.
            An empty string is treated the same as ``None``.
        thumbnail: Optional thumbnail value stored as text.

    Returns:
        The ID of the updated or newly inserted row.

    Raises:
        TypeError: If ``architecture`` is not JSON-serializable.
    """
    architecture_json = json.dumps(architecture)

    # An empty hash carries no identity. Storing '' in the UNIQUE file_hash
    # column would make the second hash-less save fail with an
    # IntegrityError, so normalize it to None (stored as NULL, and NULLs do
    # not collide under a UNIQUE constraint).
    file_hash = file_hash or None

    with get_db() as conn:
        # Check if model with same hash already exists
        if file_hash is not None:
            existing = conn.execute(
                "SELECT id FROM saved_models WHERE file_hash = ?",
                (file_hash,)
            ).fetchone()
            if existing:
                # Update existing entry
                conn.execute("""
                    UPDATE saved_models
                    SET name = ?, framework = ?, total_parameters = ?,
                    layer_count = ?, architecture_json = ?,
                    thumbnail = ?, created_at = CURRENT_TIMESTAMP
                    WHERE file_hash = ?
                """, (name, framework, total_parameters, layer_count,
                      architecture_json, thumbnail, file_hash))
                return existing['id']

        # Insert new entry
        cursor = conn.execute("""
            INSERT INTO saved_models
            (name, framework, total_parameters, layer_count, architecture_json, file_hash, thumbnail)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (name, framework, total_parameters, layer_count,
              architecture_json, file_hash, thumbnail))
        return cursor.lastrowid
def get_saved_models() -> List[dict]:
    """Return metadata for every saved model, most recently created first.

    Only the summary columns are selected; the stored architecture JSON is
    left out of the result.
    """
    query = """
        SELECT id, name, framework, total_parameters, layer_count,
        thumbnail, created_at
        FROM saved_models
        ORDER BY created_at DESC
    """
    with get_db() as conn:
        records = conn.execute(query).fetchall()
    # Convert each sqlite3.Row into a plain dict keyed by column name.
    return list(map(dict, records))
def get_model_by_id(model_id: int) -> Optional[dict]:
    """Fetch one saved model, including its decoded architecture.

    Returns a dict of the row's columns in which the stored
    ``architecture_json`` text is replaced by an ``architecture`` entry
    holding the parsed JSON, or ``None`` when no row has the given ID.
    """
    query = """
        SELECT id, name, framework, total_parameters, layer_count,
        architecture_json, thumbnail, created_at
        FROM saved_models
        WHERE id = ?
    """
    with get_db() as conn:
        row = conn.execute(query, (model_id,)).fetchone()

    if row is None:
        return None

    record = dict(row)
    # Swap the raw JSON text for the decoded structure under a new key.
    record['architecture'] = json.loads(record.pop('architecture_json'))
    return record
def delete_model(model_id: int) -> bool:
    """Remove the saved model with the given ID.

    Returns ``True`` when a row was deleted and ``False`` when no row
    matched the ID.
    """
    with get_db() as conn:
        affected_rows = conn.execute(
            "DELETE FROM saved_models WHERE id = ?",
            (model_id,),
        ).rowcount
    return affected_rows > 0
def model_exists_by_hash(file_hash: str) -> Optional[int]:
    """Look up a saved model by its file hash.

    Returns the ID of the matching row, or ``None`` when no saved model
    has that hash.
    """
    with get_db() as conn:
        match = conn.execute(
            "SELECT id FROM saved_models WHERE file_hash = ?",
            (file_hash,),
        ).fetchone()
    if match is None:
        return None
    return match['id']
# Initialize database on module load: importing this module creates the
# saved_models table and its indexes if they do not exist yet.
init_db()
|