vedaco commited on
Commit
0fe7d00
·
verified ·
1 Parent(s): d1e46fe

Update database.py

Browse files
Files changed (1) hide show
  1. database.py +154 -25
database.py CHANGED
@@ -1,25 +1,27 @@
1
- """Database for conversations"""
2
 
3
  import sqlite3
 
4
  from typing import List, Dict
5
  from config import DATABASE_PATH
6
 
7
 
8
  class VedaDatabase:
9
- """Database handler"""
10
-
11
  def __init__(self):
12
  self._init_db()
13
-
14
  def _get_conn(self):
15
  conn = sqlite3.connect(DATABASE_PATH)
16
  conn.row_factory = sqlite3.Row
17
  return conn
18
-
19
  def _init_db(self):
20
  conn = self._get_conn()
21
  cursor = conn.cursor()
22
-
 
23
  cursor.execute('''
24
  CREATE TABLE IF NOT EXISTS conversations (
25
  id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -29,40 +31,66 @@ class VedaDatabase:
29
  feedback INTEGER DEFAULT 0
30
  )
31
  ''')
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  conn.commit()
34
  conn.close()
35
-
 
36
  def save_conversation(self, user_input: str, response: str) -> int:
37
  conn = self._get_conn()
38
  cursor = conn.cursor()
39
-
40
  cursor.execute('''
41
  INSERT INTO conversations (user_input, assistant_response)
42
  VALUES (?, ?)
43
  ''', (user_input, response))
44
-
45
  conv_id = cursor.lastrowid
46
  conn.commit()
47
  conn.close()
48
-
49
  return conv_id
50
-
51
  def update_feedback(self, conv_id: int, feedback: int):
52
  conn = self._get_conn()
53
  cursor = conn.cursor()
54
-
55
  cursor.execute('''
56
  UPDATE conversations SET feedback = ? WHERE id = ?
57
  ''', (feedback, conv_id))
58
-
59
  conn.commit()
60
  conn.close()
61
-
62
  def get_good_conversations(self, limit: int = 100) -> List[Dict]:
63
  conn = self._get_conn()
64
  cursor = conn.cursor()
65
-
66
  cursor.execute('''
67
  SELECT user_input, assistant_response
68
  FROM conversations
@@ -70,28 +98,129 @@ class VedaDatabase:
70
  ORDER BY timestamp DESC
71
  LIMIT ?
72
  ''', (limit,))
73
-
74
  rows = cursor.fetchall()
75
  conn.close()
76
-
77
  return [dict(row) for row in rows]
78
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def get_stats(self) -> Dict:
80
  conn = self._get_conn()
81
  cursor = conn.cursor()
82
-
83
  cursor.execute('SELECT COUNT(*) FROM conversations')
84
  total = cursor.fetchone()[0]
85
-
86
  cursor.execute('SELECT COUNT(*) FROM conversations WHERE feedback > 0')
87
  positive = cursor.fetchone()[0]
88
-
89
  cursor.execute('SELECT COUNT(*) FROM conversations WHERE feedback < 0')
90
  negative = cursor.fetchone()[0]
91
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  conn.close()
93
-
94
- return {'total': total, 'positive': positive, 'negative': negative}
95
 
96
 
97
  db = VedaDatabase()
 
1
+ """Database for conversations and distillation data"""
2
 
3
  import sqlite3
4
+ from datetime import datetime
5
  from typing import List, Dict
6
  from config import DATABASE_PATH
7
 
8
 
9
  class VedaDatabase:
10
+ """Database handler with distillation support"""
11
+
12
  def __init__(self):
13
  self._init_db()
14
+
15
  def _get_conn(self):
16
  conn = sqlite3.connect(DATABASE_PATH)
17
  conn.row_factory = sqlite3.Row
18
  return conn
19
+
20
  def _init_db(self):
21
  conn = self._get_conn()
22
  cursor = conn.cursor()
23
+
24
+ # Regular conversations table
25
  cursor.execute('''
26
  CREATE TABLE IF NOT EXISTS conversations (
27
  id INTEGER PRIMARY KEY AUTOINCREMENT,
 
31
  feedback INTEGER DEFAULT 0
32
  )
33
  ''')
34
+
35
+ # Distillation data table (teacher responses)
36
+ cursor.execute('''
37
+ CREATE TABLE IF NOT EXISTS distillation_data (
38
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
39
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
40
+ user_input TEXT NOT NULL,
41
+ teacher_response TEXT NOT NULL,
42
+ student_response TEXT,
43
+ used_for_training BOOLEAN DEFAULT 0,
44
+ quality_score REAL DEFAULT 0
45
+ )
46
+ ''')
47
+
48
+ # Training history
49
+ cursor.execute('''
50
+ CREATE TABLE IF NOT EXISTS training_history (
51
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
52
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
53
+ training_type TEXT,
54
+ samples_used INTEGER,
55
+ epochs INTEGER,
56
+ final_loss REAL
57
+ )
58
+ ''')
59
+
60
  conn.commit()
61
  conn.close()
62
+
63
+ # ===== Conversations =====
64
  def save_conversation(self, user_input: str, response: str) -> int:
65
  conn = self._get_conn()
66
  cursor = conn.cursor()
67
+
68
  cursor.execute('''
69
  INSERT INTO conversations (user_input, assistant_response)
70
  VALUES (?, ?)
71
  ''', (user_input, response))
72
+
73
  conv_id = cursor.lastrowid
74
  conn.commit()
75
  conn.close()
76
+
77
  return conv_id
78
+
79
  def update_feedback(self, conv_id: int, feedback: int):
80
  conn = self._get_conn()
81
  cursor = conn.cursor()
82
+
83
  cursor.execute('''
84
  UPDATE conversations SET feedback = ? WHERE id = ?
85
  ''', (feedback, conv_id))
86
+
87
  conn.commit()
88
  conn.close()
89
+
90
  def get_good_conversations(self, limit: int = 100) -> List[Dict]:
91
  conn = self._get_conn()
92
  cursor = conn.cursor()
93
+
94
  cursor.execute('''
95
  SELECT user_input, assistant_response
96
  FROM conversations
 
98
  ORDER BY timestamp DESC
99
  LIMIT ?
100
  ''', (limit,))
101
+
102
  rows = cursor.fetchall()
103
  conn.close()
104
+
105
  return [dict(row) for row in rows]
106
+
107
+ # ===== Distillation =====
108
+ def save_distillation_data(
109
+ self,
110
+ user_input: str,
111
+ teacher_response: str,
112
+ student_response: str = None,
113
+ quality_score: float = 0.0
114
+ ) -> int:
115
+ conn = self._get_conn()
116
+ cursor = conn.cursor()
117
+
118
+ cursor.execute('''
119
+ INSERT INTO distillation_data
120
+ (user_input, teacher_response, student_response, quality_score)
121
+ VALUES (?, ?, ?, ?)
122
+ ''', (user_input, teacher_response, student_response, quality_score))
123
+
124
+ data_id = cursor.lastrowid
125
+ conn.commit()
126
+ conn.close()
127
+
128
+ return data_id
129
+
130
+ def get_unused_distillation_data(self, limit: int = 500) -> List[Dict]:
131
+ """Get teacher responses not yet used for training"""
132
+ conn = self._get_conn()
133
+ cursor = conn.cursor()
134
+
135
+ cursor.execute('''
136
+ SELECT id, user_input, teacher_response
137
+ FROM distillation_data
138
+ WHERE used_for_training = 0
139
+ ORDER BY timestamp DESC
140
+ LIMIT ?
141
+ ''', (limit,))
142
+
143
+ rows = cursor.fetchall()
144
+ conn.close()
145
+
146
+ return [dict(row) for row in rows]
147
+
148
+ def mark_distillation_used(self, ids: List[int]):
149
+ """Mark distillation data as used for training"""
150
+ conn = self._get_conn()
151
+ cursor = conn.cursor()
152
+
153
+ placeholders = ",".join("?" * len(ids))
154
+ cursor.execute(f'''
155
+ UPDATE distillation_data
156
+ SET used_for_training = 1
157
+ WHERE id IN ({placeholders})
158
+ ''', ids)
159
+
160
+ conn.commit()
161
+ conn.close()
162
+
163
+ def get_distillation_count(self) -> Dict:
164
+ """Get count of distillation data"""
165
+ conn = self._get_conn()
166
+ cursor = conn.cursor()
167
+
168
+ cursor.execute('SELECT COUNT(*) FROM distillation_data')
169
+ total = cursor.fetchone()[0]
170
+
171
+ cursor.execute('SELECT COUNT(*) FROM distillation_data WHERE used_for_training = 0')
172
+ unused = cursor.fetchone()[0]
173
+
174
+ cursor.execute('SELECT COUNT(*) FROM distillation_data WHERE used_for_training = 1')
175
+ used = cursor.fetchone()[0]
176
+
177
+ conn.close()
178
+
179
+ return {"total": total, "unused": unused, "used": used}
180
+
181
+ # ===== Stats =====
182
  def get_stats(self) -> Dict:
183
  conn = self._get_conn()
184
  cursor = conn.cursor()
185
+
186
  cursor.execute('SELECT COUNT(*) FROM conversations')
187
  total = cursor.fetchone()[0]
188
+
189
  cursor.execute('SELECT COUNT(*) FROM conversations WHERE feedback > 0')
190
  positive = cursor.fetchone()[0]
191
+
192
  cursor.execute('SELECT COUNT(*) FROM conversations WHERE feedback < 0')
193
  negative = cursor.fetchone()[0]
194
+
195
+ distill = self.get_distillation_count()
196
+
197
+ conn.close()
198
+
199
+ return {
200
+ "total": total,
201
+ "positive": positive,
202
+ "negative": negative,
203
+ "distillation_total": distill["total"],
204
+ "distillation_unused": distill["unused"],
205
+ }
206
+
207
+ def save_training_history(
208
+ self,
209
+ training_type: str,
210
+ samples_used: int,
211
+ epochs: int,
212
+ final_loss: float
213
+ ):
214
+ conn = self._get_conn()
215
+ cursor = conn.cursor()
216
+
217
+ cursor.execute('''
218
+ INSERT INTO training_history (training_type, samples_used, epochs, final_loss)
219
+ VALUES (?, ?, ?, ?)
220
+ ''', (training_type, samples_used, epochs, final_loss))
221
+
222
+ conn.commit()
223
  conn.close()
 
 
224
 
225
 
226
  db = VedaDatabase()