File size: 9,355 Bytes
fafcc25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""
Conversation Memory System - SQLite-backed conversation tracker with text search and Markdown export.
Designed for offline use with optional sentence-transformers for semantic search.
"""

import json
import os
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path

import gradio as gr

# Persistent database location.
# NOTE(review): assumes /data exists and is writable (e.g. a mounted volume);
# the mkdir below runs at import time and will raise otherwise — confirm.
DB_PATH = Path("/data/conversations.db")
DB_PATH.parent.mkdir(parents=True, exist_ok=True)

# Optional semantic search: sentence-transformers + numpy.
# Loading the model can fail for reasons beyond a missing package — e.g.
# SentenceTransformer() raises OSError when no cached weights exist and the
# host is offline — so any failure here degrades to "feature unavailable"
# instead of crashing the whole app at import time.
try:
    from sentence_transformers import SentenceTransformer
    import numpy as np
    EMBED_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    HAS_EMBEDDINGS = True
except Exception:  # ImportError, OSError (missing weights), etc.
    HAS_EMBEDDINGS = False

class ConversationMemory:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self._init_db()
    
    def _init_db(self):
        with sqlite3.connect(self.db_path) as conn:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS threads (
                    thread_id TEXT PRIMARY KEY,
                    title TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                );
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    thread_id TEXT NOT NULL,
                    role TEXT NOT NULL CHECK(role IN ('user','assistant','system','agent-zero')),
                    content TEXT NOT NULL,
                    source TEXT DEFAULT 'user',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (thread_id) REFERENCES threads(thread_id)
                );
                CREATE INDEX IF NOT EXISTS idx_messages_thread ON messages(thread_id);
                CREATE INDEX IF NOT EXISTS idx_threads_updated ON threads(updated_at);
            """)
            if HAS_EMBEDDINGS:
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS embeddings (
                        message_id INTEGER PRIMARY KEY,
                        embedding BLOB,
                        FOREIGN KEY (message_id) REFERENCES messages(id)
                    )
                """)
            conn.commit()
    
    def save_message(self, role: str, content: str, thread_id: str = None, source: str = "user"):
        if thread_id is None:
            thread_id = f"thread-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("INSERT OR IGNORE INTO threads (thread_id, title) VALUES (?, ?)",
                        (thread_id, thread_id))
            c = conn.execute("INSERT INTO messages (thread_id, role, content, source) VALUES (?,?,?,?)",
                        (thread_id, role, content, source))
            conn.execute("UPDATE threads SET updated_at = CURRENT_TIMESTAMP WHERE thread_id = ?", (thread_id,))
            conn.commit()
            msg_id = c.lastrowid
            if HAS_EMBEDDINGS:
                emb = EMBED_MODEL.encode(content[:512])
                conn.execute("INSERT INTO embeddings (message_id, embedding) VALUES (?,?)",
                            (msg_id, emb.tobytes()))
                conn.commit()
        return thread_id
    
    def search_text(self, query: str, limit: int = 20):
        with sqlite3.connect(self.db_path) as conn:
            return conn.execute(
                "SELECT thread_id, role, content, created_at FROM messages WHERE content LIKE ? ORDER BY created_at DESC LIMIT ?",
                (f"%{query}%", limit)
            ).fetchall()
    
    def search_semantic(self, query: str, limit: int = 10):
        if not HAS_EMBEDDINGS:
            return [("error", "system", "Install sentence-transformers for semantic search", "")]
        q_emb = EMBED_MODEL.encode(query)
        results = []
        with sqlite3.connect(self.db_path) as conn:
            rows = conn.execute("""
                SELECT m.thread_id, m.role, m.content, m.created_at, e.embedding
                FROM messages m JOIN embeddings e ON m.id = e.message_id
                ORDER BY m.created_at DESC LIMIT 500
            """).fetchall()
        for row in rows:
            emb = np.frombuffer(row[4], dtype=np.float32)
            score = np.dot(q_emb, emb) / (np.linalg.norm(q_emb) * np.linalg.norm(emb) + 1e-8)
            results.append((score, row[0], row[1], row[2][:500], row[3]))
        results.sort(key=lambda x: x[0], reverse=True)
        return [(r[1], r[2], r[3], r[4]) for r in results[:limit]]
    
    def list_threads(self):
        with sqlite3.connect(self.db_path) as conn:
            return conn.execute(
                "SELECT thread_id, title, created_at, updated_at, (SELECT COUNT(*) FROM messages WHERE thread_id = t.thread_id) as msg_count FROM threads t ORDER BY updated_at DESC LIMIT 50"
            ).fetchall()
    
    def get_thread(self, thread_id: str):
        with sqlite3.connect(self.db_path) as conn:
            return conn.execute(
                "SELECT role, content, source, created_at FROM messages WHERE thread_id = ? ORDER BY created_at",
                (thread_id,)
            ).fetchall()
    
    def export_markdown(self, thread_id: str) -> str:
        msgs = self.get_thread(thread_id)
        md = f"# Thread: {thread_id}\n\n"
        for role, content, source, ts in msgs:
            md += f"## {role} ({source}) - {ts}\n\n{content}\n\n---\n\n"
        return md
    
    def export_all(self) -> str:
        threads = self.list_threads()
        md = "# All Conversations\n\n"
        for tid, title, created, updated, count in threads:
            md += f"## {title} ({count} msgs)\n"
            md += f"Created: {created} | Updated: {updated}\n\n"
            for role, content, source, ts in self.get_thread(tid):
                md += f"### {role} ({source}) - {ts}\n\n{content[:500]}...\n\n"
            md += "---\n\n"
        return md

# Shared, module-level store used by every Gradio callback below.
memory = ConversationMemory(str(DB_PATH))

# Gradio front-end: four tabs (save / search / browse / export), all wired
# to the shared module-level `memory` store.
with gr.Blocks(title="Conversation Memory", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""# ๐Ÿ’พ Conversation Memory System
    **SQLite-backed** | **Text + Semantic Search** | **Markdown Export** | **Fully Offline**
    """)

    with gr.Tabs():
        with gr.TabItem("๐Ÿ’ฌ Save"):
            role_dd = gr.Dropdown(["user", "assistant", "system", "agent-zero"], label="Role", value="user")
            msg_box = gr.Textbox(label="Message", lines=4, placeholder="Enter conversation message...")
            tid_box = gr.Textbox(label="Thread ID (optional)", placeholder="auto-generated if blank")
            src_box = gr.Textbox(label="Source", value="user")
            save_button = gr.Button("Save Message")
            status_box = gr.Textbox(label="Status")

            def on_save(role_val, text, tid, src):
                # Blank thread id -> pass None so the store auto-generates one.
                saved_tid = memory.save_message(role_val, text, tid or None, src)
                return f"Saved to thread: {saved_tid}"

            save_button.click(on_save, [role_dd, msg_box, tid_box, src_box], status_box)

        with gr.TabItem("๐Ÿ” Search"):
            query_box = gr.Textbox(label="Search Query")
            mode_radio = gr.Radio(["Text", "Semantic"], label="Mode", value="Text")
            run_search = gr.Button("Search")
            results_df = gr.Dataframe(
                headers=["Thread", "Role", "Content", "Time"],
                label="Results"
            )

            def on_search(query, mode):
                # Dispatch to substring or embedding search; trim content for display.
                if mode == "Text":
                    hits = memory.search_text(query)
                else:
                    hits = memory.search_semantic(query)
                return [[tid, who, text[:300], ts] for tid, who, text, ts in hits]

            run_search.click(on_search, [query_box, mode_radio], results_df)

        with gr.TabItem("๐Ÿ“‹ Threads"):
            threads_df = gr.Dataframe(
                headers=["Thread ID", "Title", "Created", "Updated", "Messages"],
                label="All Threads"
            )
            reload_button = gr.Button("Refresh")
            reload_button.click(lambda: memory.list_threads(), None, threads_df)

            detail_tid = gr.Textbox(label="Thread ID")
            detail_button = gr.Button("Show Thread")
            detail_md = gr.Markdown(label="Thread Content")

            def on_show(tid):
                # Render one thread as Markdown; prompt when the id is blank.
                if not tid:
                    return "Enter a thread ID"
                rendered = []
                for who, text, src, ts in memory.get_thread(tid):
                    rendered.append(f"**{who}** ({src}) - {ts}\n\n{text}")
                return "\n\n".join(rendered)

            detail_button.click(on_show, detail_tid, detail_md)

        with gr.TabItem("๐Ÿ“ค Export"):
            export_tid = gr.Textbox(label="Thread ID (or 'all' for everything)")
            run_export = gr.Button("Export to Markdown")
            export_md = gr.Markdown(label="Exported Markdown")

            def on_export(tid):
                # 'all' dumps every thread; anything else exports one thread.
                return memory.export_all() if tid == "all" else memory.export_markdown(tid)

            run_export.click(on_export, export_tid, export_md)

demo.queue().launch(server_name="0.0.0.0", server_port=7860)