Spaces:
Runtime error
Runtime error
| """ | |
| Database module for storing processed model architectures. | |
| Uses SQLite for simple, file-based persistence. | |
| """ | |
| import json | |
| import os | |
| import sqlite3 | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from contextlib import contextmanager | |
# Location of the SQLite file. The DATABASE_PATH environment variable takes
# precedence (used for Docker deployments); otherwise "models.db" lives in the
# parent of this module's directory.
_DEFAULT_DB_PATH = Path(__file__).parent.parent / "models.db"
DB_PATH = Path(os.environ.get("DATABASE_PATH", _DEFAULT_DB_PATH))
def get_connection():
    """Open a new SQLite connection to ``DB_PATH``.

    The connection's ``row_factory`` is ``sqlite3.Row``, so fetched rows
    support access by column name (``row["id"]``) and ``dict(row)``.
    The connection is returned open; closing it is up to the caller.
    """
    connection = sqlite3.connect(os.fspath(DB_PATH))
    connection.row_factory = sqlite3.Row
    return connection
# Required: without this decorator get_db() returns a plain generator, which
# does not implement __enter__/__exit__, so every `with get_db() as conn:`
# (including init_db() at import time) fails.
@contextmanager
def get_db():
    """Context manager for database connections.

    Yields a connection from ``get_connection()``. When the ``with`` block
    exits normally, the connection is committed. If the block or the commit
    raises an ``Exception``, the connection is rolled back and the exception
    re-raised. The connection is closed in every case.

    Yields:
        sqlite3.Connection: an open connection with ``sqlite3.Row`` rows.
    """
    conn = get_connection()
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def init_db():
    """Create the ``saved_models`` table and its indexes when absent.

    Every statement is ``IF NOT EXISTS``, so running this against an
    already-initialized database leaves it unchanged.
    """
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS saved_models (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            framework TEXT NOT NULL,
            total_parameters INTEGER DEFAULT 0,
            layer_count INTEGER DEFAULT 0,
            architecture_json TEXT NOT NULL,
            thumbnail TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            file_hash TEXT UNIQUE
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_name ON saved_models(name)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_created_at ON saved_models(created_at)
        """,
    )
    with get_db() as conn:
        for statement in schema_statements:
            conn.execute(statement)
def save_model(
    name: str,
    framework: str,
    total_parameters: int,
    layer_count: int,
    architecture: dict,
    file_hash: Optional[str] = None,
    thumbnail: Optional[str] = None
) -> int:
    """
    Save a model architecture to the database.

    If ``file_hash`` matches an existing row, that row is overwritten in place.
    Its ``created_at`` is reset to the current time and its ID is returned.
    Otherwise a new row is inserted.

    Args:
        name: Display name of the model.
        framework: Framework label, stored as given.
        total_parameters: Parameter count, stored as given.
        layer_count: Layer count, stored as given.
        architecture: Architecture description; stored as ``json.dumps`` text.
        file_hash: Optional deduplication key. ``None`` and ``""`` both mean
            "no hash" and are stored as NULL.
        thumbnail: Optional thumbnail text, stored as given.

    Returns:
        The ID of the updated or newly inserted row.

    Raises:
        TypeError / ValueError: propagated from ``json.dumps`` when
            ``architecture`` cannot be serialized.
    """
    # Serialize first so an unserializable payload fails before any DB work.
    architecture_json = json.dumps(architecture)
    # file_hash is a UNIQUE column. An empty string used to skip the
    # duplicate check below yet still be stored, so a second save with ""
    # raised IntegrityError. Normalize "no hash" to NULL; SQLite allows any
    # number of NULLs in a UNIQUE column.
    file_hash = file_hash or None

    with get_db() as conn:
        if file_hash is not None:
            existing = conn.execute(
                "SELECT id FROM saved_models WHERE file_hash = ?",
                (file_hash,)
            ).fetchone()
            if existing is not None:
                # Same file seen before: refresh its data instead of duplicating.
                conn.execute("""
                    UPDATE saved_models
                    SET name = ?, framework = ?, total_parameters = ?,
                        layer_count = ?, architecture_json = ?,
                        thumbnail = ?, created_at = CURRENT_TIMESTAMP
                    WHERE file_hash = ?
                """, (name, framework, total_parameters, layer_count,
                      architecture_json, thumbnail, file_hash))
                return existing['id']

        # Insert new entry
        cursor = conn.execute("""
            INSERT INTO saved_models
            (name, framework, total_parameters, layer_count, architecture_json, file_hash, thumbnail)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (name, framework, total_parameters, layer_count,
              architecture_json, file_hash, thumbnail))
        return cursor.lastrowid
def get_saved_models() -> List[dict]:
    """
    Get all saved models (metadata only, not full architecture).

    Returns one dict per row with keys ``id``, ``name``, ``framework``,
    ``total_parameters``, ``layer_count``, ``thumbnail`` and ``created_at``,
    newest first.
    """
    with get_db() as conn:
        # CURRENT_TIMESTAMP has one-second resolution. The ``id DESC``
        # tiebreaker gives rows saved within the same second a deterministic
        # order, and idx_created_at (which implicitly ends with rowid) still
        # serves this ordering.
        rows = conn.execute("""
            SELECT id, name, framework, total_parameters, layer_count,
                   thumbnail, created_at
            FROM saved_models
            ORDER BY created_at DESC, id DESC
        """).fetchall()
        return [dict(row) for row in rows]
def get_model_by_id(model_id: int) -> Optional[dict]:
    """Fetch one saved model together with its decoded architecture.

    Args:
        model_id: Primary key of the ``saved_models`` row.

    Returns:
        A dict with ``id``, ``name``, ``framework``, ``total_parameters``,
        ``layer_count``, ``thumbnail``, ``created_at`` and ``architecture``
        (the stored JSON parsed with ``json.loads``). Returns ``None`` if no
        row has that ID.
    """
    query = """
        SELECT id, name, framework, total_parameters, layer_count,
               architecture_json, thumbnail, created_at
        FROM saved_models
        WHERE id = ?
    """
    with get_db() as conn:
        record = conn.execute(query, (model_id,)).fetchone()
        if record is None:
            return None
        model = dict(record)
        # Replace the raw JSON column with its parsed form under a new key.
        model['architecture'] = json.loads(model.pop('architecture_json'))
        return model
def delete_model(model_id: int) -> bool:
    """Remove the saved model with the given ID.

    Returns:
        ``True`` when a row was deleted, ``False`` when no row had that ID.
    """
    with get_db() as conn:
        deleted_rows = conn.execute(
            "DELETE FROM saved_models WHERE id = ?", (model_id,)
        ).rowcount
        return deleted_rows > 0
def model_exists_by_hash(file_hash: str) -> Optional[int]:
    """Look up a saved model by its file hash.

    Returns:
        The matching row's ``id``, or ``None`` if no row has this hash.
    """
    with get_db() as conn:
        match = conn.execute(
            "SELECT id FROM saved_models WHERE file_hash = ?",
            (file_hash,)
        ).fetchone()
        return None if match is None else match['id']
# Ensure the schema exists as an import-time side effect of this module.
# Any exception raised by init_db() here propagates out of the import.
init_db()