"""Database for conversations and distillation data"""

import sqlite3
from contextlib import contextmanager
from datetime import datetime  # NOTE(review): unused here; kept in case other modules rely on it
from typing import Dict, Iterator, List, Optional

from config import DATABASE_PATH


class VedaDatabase:
    """Database handler with distillation support.

    Wraps the SQLite database at ``DATABASE_PATH`` with three tables:
    ``conversations`` (user/assistant turns plus a feedback score),
    ``distillation_data`` (teacher responses, optional student responses and
    a ``used_for_training`` flag) and ``training_history`` (one row per
    recorded training run). Every public method opens its own short-lived
    connection and always closes it, even when a query fails.
    """

    def __init__(self):
        self._init_db()

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows can be accessed by column name."""
        conn = sqlite3.connect(DATABASE_PATH)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed on success, rolled back on
        error, and closed in every case.

        ``with conn:`` on a sqlite3 connection only manages the transaction;
        it does not close the connection, hence the explicit ``finally``.
        """
        conn = self._get_conn()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Create the conversations, distillation and training tables if missing."""
        with self._connection() as conn:
            cursor = conn.cursor()

            # Regular conversations table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    user_input TEXT NOT NULL,
                    assistant_response TEXT NOT NULL,
                    feedback INTEGER DEFAULT 0
                )
            ''')

            # Distillation data table (teacher responses)
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS distillation_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    user_input TEXT NOT NULL,
                    teacher_response TEXT NOT NULL,
                    student_response TEXT,
                    used_for_training BOOLEAN DEFAULT 0,
                    quality_score REAL DEFAULT 0
                )
            ''')

            # Training history
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS training_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    training_type TEXT,
                    samples_used INTEGER,
                    epochs INTEGER,
                    final_loss REAL
                )
            ''')

    # ===== Conversations =====

    def save_conversation(self, user_input: str, response: str) -> int:
        """Insert a conversation turn and return its new row id."""
        with self._connection() as conn:
            cursor = conn.execute('''
                INSERT INTO conversations (user_input, assistant_response)
                VALUES (?, ?)
            ''', (user_input, response))
            return cursor.lastrowid

    def update_feedback(self, conv_id: int, feedback: int):
        """Set the feedback score of conversation ``conv_id``.

        ``get_good_conversations`` and ``get_stats`` treat values > 0 as
        positive and values < 0 as negative. An unknown ``conv_id`` matches
        no row and is silently ignored.
        """
        with self._connection() as conn:
            conn.execute('''
                UPDATE conversations SET feedback = ? WHERE id = ?
            ''', (feedback, conv_id))

    def get_good_conversations(self, limit: int = 100) -> List[Dict]:
        """Return up to ``limit`` positively rated conversations, newest first.

        Each item is a dict with keys ``user_input`` and ``assistant_response``.
        """
        with self._connection() as conn:
            # CURRENT_TIMESTAMP has one-second resolution; ``id DESC`` breaks
            # ties so rows written in the same second still come out newest-first.
            rows = conn.execute('''
                SELECT user_input, assistant_response FROM conversations
                WHERE feedback > 0
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
            ''', (limit,)).fetchall()
        return [dict(row) for row in rows]

    # ===== Distillation =====

    def save_distillation_data(
        self,
        user_input: str,
        teacher_response: str,
        student_response: Optional[str] = None,
        quality_score: float = 0.0
    ) -> int:
        """Store a teacher response (and optional student response); return its row id."""
        with self._connection() as conn:
            cursor = conn.execute('''
                INSERT INTO distillation_data
                (user_input, teacher_response, student_response, quality_score)
                VALUES (?, ?, ?, ?)
            ''', (user_input, teacher_response, student_response, quality_score))
            return cursor.lastrowid

    def get_unused_distillation_data(self, limit: int = 500) -> List[Dict]:
        """Get teacher responses not yet used for training, newest first.

        Each item is a dict with keys ``id``, ``user_input`` and
        ``teacher_response``; pass the ids to ``mark_distillation_used``
        once they have been trained on.
        """
        with self._connection() as conn:
            # ``id DESC`` breaks ties between rows sharing a one-second timestamp.
            rows = conn.execute('''
                SELECT id, user_input, teacher_response
                FROM distillation_data
                WHERE used_for_training = 0
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
            ''', (limit,)).fetchall()
        return [dict(row) for row in rows]

    def mark_distillation_used(self, ids: List[int]):
        """Mark distillation data as used for training.

        One parameterized UPDATE per id is run inside a single transaction
        instead of one ``IN (?, ?, ...)`` clause, so arbitrarily long id lists
        cannot exceed SQLite's bound-parameter limit. An empty list is a no-op.
        """
        params = [(data_id,) for data_id in ids]
        if not params:
            return
        with self._connection() as conn:
            conn.executemany(
                'UPDATE distillation_data SET used_for_training = 1 WHERE id = ?',
                params,
            )

    @staticmethod
    def _distillation_counts(conn: sqlite3.Connection) -> Dict[str, int]:
        """Count distillation rows on an open connection in a single query.

        A single aggregate reads one consistent snapshot, so ``total`` cannot
        disagree with ``unused``/``used`` because of a concurrent write.
        SUM() over an empty table is NULL, hence the COALESCE.
        """
        row = conn.execute('''
            SELECT COUNT(*) AS total,
                   COALESCE(SUM(used_for_training = 0), 0) AS unused,
                   COALESCE(SUM(used_for_training = 1), 0) AS used
            FROM distillation_data
        ''').fetchone()
        return {"total": row["total"], "unused": row["unused"], "used": row["used"]}

    def get_distillation_count(self) -> Dict:
        """Get count of distillation data as ``{"total", "unused", "used"}``."""
        with self._connection() as conn:
            return self._distillation_counts(conn)

    # ===== Stats =====

    def get_stats(self) -> Dict:
        """Return conversation feedback counts plus distillation counts.

        Keys: ``total``, ``positive`` (feedback > 0), ``negative``
        (feedback < 0), ``distillation_total`` and ``distillation_unused``.
        All counts are read over one connection.
        """
        with self._connection() as conn:
            conv = conn.execute('''
                SELECT COUNT(*) AS total,
                       COALESCE(SUM(feedback > 0), 0) AS positive,
                       COALESCE(SUM(feedback < 0), 0) AS negative
                FROM conversations
            ''').fetchone()
            distill = self._distillation_counts(conn)

        return {
            "total": conv["total"],
            "positive": conv["positive"],
            "negative": conv["negative"],
            "distillation_total": distill["total"],
            "distillation_unused": distill["unused"],
        }

    def save_training_history(
        self,
        training_type: str,
        samples_used: int,
        epochs: int,
        final_loss: float
    ):
        """Record one training run in ``training_history``."""
        with self._connection() as conn:
            conn.execute('''
                INSERT INTO training_history
                (training_type, samples_used, epochs, final_loss)
                VALUES (?, ?, ?, ?)
            ''', (training_type, samples_used, epochs, final_loss))


db = VedaDatabase()