# Spaces:
# Running
# Running
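"""
db.py — SQLite-backed persistent storage for Wardrobe Assistant on HF Spaces.

Setup:
    1. HF Space → Settings → Persistent Storage → Enable (mounts at /data).
    2. Place this file next to app.py.
    3. No extra pip packages are needed — sqlite3 is in the standard library.

The database file is taken from the DB_PATH environment variable when it is
set and non-empty; otherwise /data/wardrobe.db when /data exists; otherwise
./wardrobe.db for local development.
"""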
from __future__ import annotations

import json
import os
import sqlite3
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone, timedelta
from typing import Any
| # --------------------------------------------------------------------------- | |
| # Path resolution | |
| # --------------------------------------------------------------------------- | |
| def _resolve_db_path() -> str: | |
| custom = os.getenv("DB_PATH") | |
| if custom: | |
| return custom | |
| if os.path.isdir("/data"): # HF Spaces persistent volume | |
| return "/data/wardrobe.db" | |
| return "./wardrobe.db" # local dev fallback | |
| DB_PATH = _resolve_db_path() | |
# ---------------------------------------------------------------------------
# Connection
# ---------------------------------------------------------------------------
@contextmanager
def _conn() -> Iterator[sqlite3.Connection]:
    """Open a short-lived connection to ``DB_PATH`` for use in a ``with`` block.

    The connection returns ``sqlite3.Row`` rows, uses WAL journaling and has
    foreign-key enforcement switched on. When the ``with`` body finishes
    normally, the transaction is committed. When the body raises an
    ``Exception``, the transaction is rolled back and the error re-raised.
    The connection is closed in every case.

    Usage::

        with _conn() as con:
            con.execute("SELECT 1")
    """
    con = sqlite3.connect(DB_PATH, check_same_thread=False)
    try:
        con.row_factory = sqlite3.Row
        # The journal mode is persisted in the database file, so re-issuing
        # it on every connection is harmless.
        con.execute("PRAGMA journal_mode=WAL")
        # Foreign-key enforcement is a per-connection setting in SQLite. It is
        # required for the ON DELETE CASCADE clauses in the schema to fire.
        con.execute("PRAGMA foreign_keys=ON")
        try:
            yield con
            con.commit()
        except Exception:
            con.rollback()
            raise
    finally:
        # Also runs if a PRAGMA above fails, so the handle is never leaked.
        con.close()
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
def init_db() -> None:
    """Create all tables and indexes that do not exist yet.

    Every statement is ``CREATE … IF NOT EXISTS``, so this can run on every
    startup without touching existing data.

    Tables:
        items            one row per garment; ``description`` holds JSON text
                         (column default ``'{}'``).
        outfit_feedback  a wear/skip/save action on a top + bottom pair; both
                         ids reference ``items`` with ``ON DELETE CASCADE``.
        search_cache     JSON payloads keyed by ``cache_key`` with
                         ``created_at`` / ``expires_at`` timestamps.
    """
    ddl = """
    CREATE TABLE IF NOT EXISTS items (
        id TEXT PRIMARY KEY,
        image_url TEXT NOT NULL DEFAULT '',
        category TEXT NOT NULL DEFAULT 'Unknown',
        color TEXT NOT NULL DEFAULT 'Unknown',
        pattern TEXT NOT NULL DEFAULT 'Solid',
        fabric TEXT NOT NULL DEFAULT 'Unknown',
        fit TEXT NOT NULL DEFAULT 'Unknown',
        season TEXT NOT NULL DEFAULT 'All-Season',
        style TEXT NOT NULL DEFAULT 'casual',
        type TEXT NOT NULL DEFAULT 'unknown',
        description TEXT NOT NULL DEFAULT '{}',
        created_at TEXT NOT NULL
    );
    CREATE TABLE IF NOT EXISTS outfit_feedback (
        id TEXT PRIMARY KEY,
        top_id TEXT NOT NULL,
        bottom_id TEXT NOT NULL,
        occasion TEXT NOT NULL DEFAULT 'casual',
        action TEXT NOT NULL CHECK(action IN ('wear','skip','save')),
        score INTEGER,
        created_at TEXT NOT NULL,
        FOREIGN KEY (top_id) REFERENCES items(id) ON DELETE CASCADE,
        FOREIGN KEY (bottom_id) REFERENCES items(id) ON DELETE CASCADE
    );
    CREATE TABLE IF NOT EXISTS search_cache (
        cache_key TEXT PRIMARY KEY,
        payload TEXT NOT NULL,
        created_at TEXT NOT NULL,
        expires_at TEXT NOT NULL
    );
    CREATE INDEX IF NOT EXISTS idx_items_type ON items(type);
    CREATE INDEX IF NOT EXISTS idx_items_created ON items(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_fb_top ON outfit_feedback(top_id);
    CREATE INDEX IF NOT EXISTS idx_fb_bot ON outfit_feedback(bottom_id);
    """
    with _conn() as connection:
        connection.executescript(ddl)
| # --------------------------------------------------------------------------- | |
| # Internal helpers | |
| # --------------------------------------------------------------------------- | |
| def _now() -> str: | |
| return datetime.now(timezone.utc).isoformat() | |
| def _row_to_item(row: sqlite3.Row) -> dict[str, Any]: | |
| d = dict(row) | |
| raw = d.get("description", "{}") | |
| try: | |
| d["description"] = json.loads(raw) if isinstance(raw, str) else raw | |
| except (json.JSONDecodeError, TypeError): | |
| d["description"] = {} | |
| return d | |
# ---------------------------------------------------------------------------
# Items CRUD
# ---------------------------------------------------------------------------
def item_insert(item: dict[str, Any]) -> dict[str, Any]:
    """Insert a wardrobe item and return the record that was written.

    Missing or falsy text attributes fall back to the same defaults the
    ``items`` table declares, and each one is stored via ``str()``. ``id``
    defaults to a new UUID4 string and ``created_at`` to the current UTC
    time. ``description`` is stored JSON-encoded, but the returned dict holds
    the un-encoded value. A duplicate ``id`` violates the primary key, so
    sqlite3 raises ``IntegrityError``.
    """
    text_columns = (
        ("image_url", ""),
        ("category", "Unknown"),
        ("color", "Unknown"),
        ("pattern", "Solid"),
        ("fabric", "Unknown"),
        ("fit", "Unknown"),
        ("season", "All-Season"),
        ("style", "casual"),
        ("type", "unknown"),
    )
    row: dict[str, Any] = {"id": item.get("id") or str(uuid.uuid4())}
    for column, fallback in text_columns:
        row[column] = str(item.get(column) or fallback)
    raw_description = item.get("description") or {}
    row["description"] = json.dumps(raw_description)
    row["created_at"] = item.get("created_at") or _now()

    with _conn() as connection:
        connection.execute(
            """INSERT INTO items
            (id,image_url,category,color,pattern,fabric,fit,
            season,style,type,description,created_at)
            VALUES
            (:id,:image_url,:category,:color,:pattern,:fabric,:fit,
            :season,:style,:type,:description,:created_at)""",
            row,
        )
    # Hand back the decoded description rather than the stored JSON text.
    row["description"] = raw_description
    return row
def item_get_all() -> list[dict[str, Any]]:
    """Return every item, newest ``created_at`` first, with decoded descriptions."""
    with _conn() as connection:
        cursor = connection.execute("SELECT * FROM items ORDER BY created_at DESC")
        fetched = cursor.fetchall()
    return list(map(_row_to_item, fetched))
def item_get(item_id: str) -> dict[str, Any] | None:
    """Look up one item by primary key; ``None`` if no row has that id."""
    with _conn() as connection:
        found = connection.execute(
            "SELECT * FROM items WHERE id=?", (item_id,)
        ).fetchone()
    if not found:
        return None
    return _row_to_item(found)
def item_update(item_id: str, patch: dict[str, Any]) -> dict[str, Any] | None:
    """Apply a partial update to an item and return the refreshed record.

    Args:
        item_id: Primary key of the item to change.
        patch: Fields to change.
            - Only the text columns in ``allowed`` are written, each via
              ``str()``.
            - Other keys (e.g. ``id``, ``created_at``) are ignored, as are
              keys whose value is ``None``.
            - A dict ``description`` is shallow-merged into the stored
              description; any other ``description`` value is ignored.

    Returns:
        The item as re-read after the update, or ``None`` if *item_id* does
        not exist.
    """
    existing = item_get(item_id)
    if existing is None:
        return None

    merged = existing.get("description", {})
    if isinstance(patch.get("description"), dict):
        # _row_to_item can yield a non-dict description (e.g. a JSON list that
        # item_insert accepted). ``{**list}`` would raise TypeError, so merge
        # onto an empty dict instead.
        base = merged if isinstance(merged, dict) else {}
        merged = {**base, **patch["description"]}

    allowed = {"image_url", "category", "color", "pattern",
               "fabric", "fit", "season", "style", "type"}
    # Skip None values: str(None) would store the literal text "None".
    updates = {
        k: str(v) for k, v in patch.items() if k in allowed and v is not None
    }
    updates["description"] = json.dumps(merged)

    # Column names come only from ``allowed`` plus "description", so
    # interpolating them is safe; every value stays a bound parameter.
    set_clause = ", ".join(f"{k}=:{k}" for k in updates)
    updates["id"] = item_id
    with _conn() as con:
        con.execute(f"UPDATE items SET {set_clause} WHERE id=:id", updates)
    return item_get(item_id)
def item_delete(item_id: str) -> bool:
    """Delete an item by id and report whether a row was actually removed.

    Foreign keys are enforced on every connection, so ``outfit_feedback``
    rows that reference the item are removed by ``ON DELETE CASCADE``.
    """
    with _conn() as connection:
        removed = connection.execute(
            "DELETE FROM items WHERE id=?", (item_id,)
        ).rowcount
    return removed > 0
# ---------------------------------------------------------------------------
# Feedback
# ---------------------------------------------------------------------------
def feedback_record(
    top_id: str,
    bottom_id: str,
    action: str,
    occasion: str = "casual",
    score: int | None = None,
) -> dict[str, Any]:
    """Log a reaction to a top/bottom outfit pairing and return the stored row.

    *action* must be ``'wear'``, ``'skip'`` or ``'save'``. The table's CHECK
    constraint rejects anything else, and sqlite3 surfaces that as
    ``IntegrityError``. Both ids must reference existing items because
    foreign keys are enforced.
    """
    entry = dict(
        id=str(uuid.uuid4()),
        top_id=top_id,
        bottom_id=bottom_id,
        occasion=occasion,
        action=action,
        score=score,
        created_at=_now(),
    )
    with _conn() as connection:
        connection.execute(
            """INSERT INTO outfit_feedback
            (id,top_id,bottom_id,occasion,action,score,created_at)
            VALUES(:id,:top_id,:bottom_id,:occasion,:action,:score,:created_at)""",
            entry,
        )
    return entry
# ---------------------------------------------------------------------------
# Search cache
# ---------------------------------------------------------------------------
def cache_get(key: str) -> Any | None:
    """Return the decoded payload cached under *key*, or ``None``.

    Returns ``None`` in any of these cases:
        - the key is not cached;
        - the entry has expired (the stale row is also deleted);
        - the stored payload is not valid JSON.

    A payload that itself decodes to JSON ``null`` is therefore
    indistinguishable from a miss.
    """
    with _conn() as connection:
        hit = connection.execute(
            "SELECT payload, expires_at FROM search_cache WHERE cache_key=?",
            (key,),
        ).fetchone()
    if not hit:
        return None
    # expires_at (written by cache_set) and _now() are both UTC isoformat
    # strings, so comparing them as text gives chronological order.
    expired = hit["expires_at"] < _now()
    if expired:
        cache_delete(key)
        return None
    try:
        return json.loads(hit["payload"])
    except (json.JSONDecodeError, TypeError):
        return None
def cache_set(key: str, payload: Any, ttl_seconds: int = 86_400) -> None:
    """Cache *payload* (JSON-encoded) under *key* for *ttl_seconds* seconds.

    The default TTL is one day. Writing an existing key overwrites its
    payload and resets both ``created_at`` and ``expires_at`` (an upsert).
    """
    deadline = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
    expires_iso = deadline.isoformat()
    with _conn() as connection:
        params = (key, json.dumps(payload), _now(), expires_iso)
        connection.execute(
            """INSERT INTO search_cache(cache_key,payload,created_at,expires_at)
            VALUES(?,?,?,?)
            ON CONFLICT(cache_key) DO UPDATE SET
            payload=excluded.payload,
            created_at=excluded.created_at,
            expires_at=excluded.expires_at""",
            params,
        )
def cache_delete(key: str) -> None:
    """Drop the cache entry for *key*; deleting a missing key is a no-op."""
    params = [key]
    with _conn() as connection:
        connection.execute("DELETE FROM search_cache WHERE cache_key=?", params)
def cache_purge_expired() -> int:
    """Delete every cache entry whose ``expires_at`` is before now.

    Returns:
        The number of rows removed.
    """
    with _conn() as connection:
        cursor = connection.execute(
            "DELETE FROM search_cache WHERE expires_at<?", (_now(),)
        )
        purged = cursor.rowcount
    return purged