Spaces:
Paused
Paused
| import os, json, time, re, asyncio | |
| from fastapi import APIRouter, Request | |
| from fastapi.responses import JSONResponse | |
| from shared import GROQ_API_KEY, GROQ_URL, _db_read, _db_read_one, _db_write, _db_write_sync, _sanitize_text, SA_DB_PATH | |
# FastAPI router for this module.
# NOTE(review): no @router route decorators appear in this source, so the handlers in the
# "API endpoints" section are presumably registered elsewhere (add_api_route / app.py), or
# the decorators were lost — TODO confirm.
router = APIRouter()
# ═══ DB initialization ═══
# The original note says this is invoked from shared.py's init_db, while
# init_memory_tables' docstring says app.py startup — confirm which is current.
#
# DDL for the three per-user tables. It is run as a single script by
# init_memory_tables(). Every statement uses IF NOT EXISTS, so re-running it is a no-op
# for objects that already exist. It will NOT alter an existing table or index definition.
#   conversations : individual chat turns (role + content) and the page they happened on
#   user_memory   : long-term key/value facts. idx_mem_unique on (email, category, key)
#                   is the conflict target of the ON CONFLICT upsert in save_memory().
#   task_history  : log of feature usage entries written by save_task()
# created_at / updated_at are REAL columns; this module writes time.time() into them.
MEMORY_TABLES_SQL = """
CREATE TABLE IF NOT EXISTS conversations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
page_url TEXT DEFAULT '',
page_title TEXT DEFAULT '',
created_at REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_conv_email ON conversations(email, created_at DESC);
CREATE TABLE IF NOT EXISTS user_memory (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'general',
key TEXT NOT NULL,
value TEXT NOT NULL,
source TEXT DEFAULT '',
importance INTEGER DEFAULT 5,
created_at REAL NOT NULL,
updated_at REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mem_email ON user_memory(email, importance DESC);
CREATE UNIQUE INDEX IF NOT EXISTS idx_mem_unique ON user_memory(email, category, key);
CREATE TABLE IF NOT EXISTS task_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL,
task_type TEXT NOT NULL,
task_input TEXT NOT NULL,
task_output TEXT DEFAULT '',
page_url TEXT DEFAULT '',
page_title TEXT DEFAULT '',
status TEXT DEFAULT 'completed',
created_at REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_task_email ON task_history(email, created_at DESC);
"""
async def init_memory_tables():
    """Create the memory tables and indexes if they do not exist yet.

    Runs MEMORY_TABLES_SQL as one script on SA_DB_PATH and commits. Its original
    docstring says app.py calls this at startup.
    """
    import aiosqlite

    connection = aiosqlite.connect(SA_DB_PATH, timeout=10.0)
    async with connection as conn:
        await conn.executescript(MEMORY_TABLES_SQL)
        await conn.commit()
    print("✅ Memory tables initialized")
# ═══ Conversation history: save / fetch ═══
async def save_conversation(email, role, content, page_url="", page_title=""):
    """Persist one chat turn in the ``conversations`` table.

    Args:
        email: User identifier; stored lower-cased.
        role: Speaker of the turn (e.g. "user").
        content: Message text; passed through _sanitize_text and cut to 10,000 chars.
        page_url: URL of the page the turn happened on; cut to 500 chars.
        page_title: Title of that page; cut to 200 chars.

    ``None`` is accepted for ``page_url`` / ``page_title`` and stored as "".
    Callers forwarding optional JSON fields therefore no longer crash on slicing.
    """
    content = _sanitize_text(content)[:10000]
    await _db_write(
        "INSERT INTO conversations (email,role,content,page_url,page_title,created_at) VALUES (?,?,?,?,?,?)",
        (
            email.lower(),
            role,
            content,
            (page_url or "")[:500],
            (page_title or "")[:200],
            time.time(),
        ),
    )
async def get_recent_conversations(email, limit=10):
    """Return the user's last ``limit`` chat turns in chronological order.

    The query fetches the newest rows first; they are then flipped so the oldest turn
    comes first. Each item is a dict with ``role``, ``content`` and ``page_url``.
    """
    newest_first = await _db_read(
        "SELECT role, content, page_url, created_at FROM conversations WHERE email=? ORDER BY created_at DESC LIMIT ?",
        (email.lower(), limit)
    )
    turns = []
    for row in reversed(newest_first):
        turns.append({"role": row["role"], "content": row["content"], "page_url": row["page_url"]})
    return turns
# ═══ Long-term memory: save / fetch / search ═══
async def save_memory(email, category, key, value, source="", importance=5):
    """Insert or update one long-term memory fact (upsert on email + category + key).

    When a row with the same (email, category, key) already exists:
      - value, source and updated_at are overwritten;
      - importance becomes MAX(old, new), so it never decreases;
      - created_at is kept.

    Args:
        email: User identifier; stored lower-cased.
        category: Memory bucket. A falsy value falls back to "general", the column default.
        key: Fact name; sanitized and cut to 200 chars.
        value: Fact value; sanitized and cut to 2000 chars.
        source: Origin of the fact (a page URL, "manual", ...); ``None`` is stored as "".
        importance: Ordering weight used when memories are listed or searched.
    """
    now = time.time()
    category = category or "general"
    await _db_write(
        """INSERT INTO user_memory (email,category,key,value,source,importance,created_at,updated_at)
VALUES (?,?,?,?,?,?,?,?)
ON CONFLICT(email,category,key) DO UPDATE SET value=excluded.value, source=excluded.source,
importance=MAX(importance, excluded.importance), updated_at=excluded.updated_at""",
        (
            email.lower(),
            category,
            _sanitize_text(key)[:200],
            _sanitize_text(value)[:2000],
            (source or "")[:200],
            importance,
            now,
            now,
        ),
    )
async def get_all_memories(email, limit=50):
    """Return up to ``limit`` of the user's memories, most important first.

    Rows with equal importance are ordered by most recently updated. Each item is a
    plain dict with id, category, key, value, importance and updated_at.
    """
    query = (
        "SELECT id, category, key, value, importance, updated_at FROM user_memory "
        "WHERE email=? ORDER BY importance DESC, updated_at DESC LIMIT ?"
    )
    records = await _db_read(query, (email.lower(), limit))
    return list(map(dict, records))
async def search_memories(email, query, limit=10, max_keywords=20):
    """Keyword search over a user's memories (substring match on key or value).

    ``query`` is lower-cased and split into word tokens. A memory matches when ANY token
    occurs in its lower-cased key or value. Results are ordered by importance, highest
    first.

    Args:
        email: User identifier (lower-cased before lookup).
        query: Free text, e.g. the user's current chat message. ``None``, "" or text
            with no word characters returns ``[]`` without querying the DB.
        limit: Maximum number of rows returned.
        max_keywords: Cap on the number of distinct tokens used. Every token adds two
            bound parameters and two LIKE scans. Without a cap, a long message could
            exceed SQLite's bound-variable limit (999 on older builds) and fail the query.

    Returns:
        List of dicts with id, category, key, value and importance.
    """
    # In Python 3 str patterns, \w already covers Hangul; the explicit range is kept as-is.
    # dict.fromkeys removes duplicate tokens while keeping first-seen order.
    tokens = list(dict.fromkeys(re.findall(r'[\w가-힣]+', (query or "").lower())))[:max_keywords]
    if not tokens:
        return []
    # "_" is matched by \w but is a single-character wildcard in LIKE. Escape it (plus
    # "%" and the escape char itself) so every token matches literally.
    params = []
    for kw in tokens:
        escaped = kw.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        params.extend([f"%{escaped}%", f"%{escaped}%"])
    conditions = " OR ".join(
        ["(LOWER(key) LIKE ? ESCAPE '\\' OR LOWER(value) LIKE ? ESCAPE '\\')"] * len(tokens)
    )
    rows = await _db_read(
        f"SELECT id, category, key, value, importance FROM user_memory WHERE email=? AND ({conditions}) ORDER BY importance DESC LIMIT ?",
        [email.lower()] + params + [limit],
    )
    return [dict(r) for r in rows]
async def delete_memory(email, memory_id):
    """Delete a single memory row, but only if it belongs to ``email``.

    The DELETE is scoped by both id and (lower-cased) email, so a row is removed only
    when the given email owns it. Nothing is returned.
    """
    owner = email.lower()
    await _db_write_sync("DELETE FROM user_memory WHERE id=? AND email=?", (memory_id, owner))
# ═══ Task history: save / fetch ═══
async def save_task(email, task_type, task_input, task_output="", page_url="", page_title="", status="completed"):
    """Append one entry to ``task_history``.

    Args:
        email: User identifier; stored lower-cased.
        task_type: Feature name; cut to 50 chars. ``None`` is stored as "unknown",
            the same default task_log uses.
        task_input: Input text; sanitized and cut to 2000 chars.
        task_output: Output text; sanitized and cut to 2000 chars (``None`` -> "").
        page_url: Page URL; cut to 500 chars (``None`` -> "").
        page_title: Page title; cut to 200 chars (``None`` -> "").
        status: Stored as given (default "completed").
    """
    await _db_write(
        "INSERT INTO task_history (email,task_type,task_input,task_output,page_url,page_title,status,created_at) VALUES (?,?,?,?,?,?,?,?)",
        (
            email.lower(),
            (task_type or "unknown")[:50],
            _sanitize_text(task_input)[:2000],
            _sanitize_text(task_output or "")[:2000],
            (page_url or "")[:500],
            (page_title or "")[:200],
            status,
            time.time(),
        ),
    )
async def get_recent_tasks(email, limit=10):
    """Return the user's latest task-history rows, newest first.

    Args:
        email: User identifier (lower-cased before lookup).
        limit: Maximum number of rows.

    Returns:
        List of dicts with task_type, task_input, task_output, page_url, status and
        created_at.
    """
    owner = email.lower()
    history = await _db_read(
        "SELECT task_type, task_input, task_output, page_url, status, created_at FROM task_history WHERE email=? ORDER BY created_at DESC LIMIT ?",
        (owner, limit)
    )
    entries = []
    for entry in history:
        entries.append(dict(entry))
    return entries
# ═══ LLM-based automatic memory extraction (background task) ═══
# Categories the extraction prompt asks the model to use. Any other category the model
# returns is stored as "general", the user_memory column default.
_EXTRACT_CATEGORIES = frozenset(
    {"personal", "preference", "work", "skill", "location", "habit", "request"}
)


def _parse_extracted_memories(text, max_items=3):
    """Turn the model's reply into a list of (category, key, value, importance) tuples.

    Surrounding prose is tolerated by taking the outermost ``[...]`` span. Each bad
    entry is skipped on its own, so the rest of the batch still goes through. Bad
    entries are:
      - not a JSON object;
      - missing key or value;
      - importance that is not a number.
    Importance is clamped to 1..10 and defaults to 5. Returns ``[]`` when nothing
    usable is found.
    """
    m = re.search(r'\[[\s\S]*\]', text or "")
    if not m:
        return []
    try:
        items = json.loads(m.group())
    except ValueError:
        return []
    if not isinstance(items, list):
        return []
    parsed = []
    for item in items[:max_items]:
        if not isinstance(item, dict):
            continue
        cat = item.get("category")
        if not (isinstance(cat, str) and cat in _EXTRACT_CATEGORIES):
            cat = "general"
        # str() guards against the model returning numbers (e.g. an age) as values.
        key = str(item.get("key") or "").strip()
        value = str(item.get("value") or "").strip()
        try:
            imp = min(10, max(1, int(item.get("importance", 5))))
        except (TypeError, ValueError, OverflowError):
            imp = 5
        if key and value:
            parsed.append((cat, key, value, imp))
    return parsed


async def extract_memories_from_conversation(email, user_msg, assistant_msg, page_url=""):
    """Ask the LLM to extract lasting user facts from one exchange and save them.

    It does nothing when GROQ_API_KEY is unset or ``user_msg`` is empty. Errors in the
    HTTP call, parsing or saving are printed, not raised, because callers run this
    fire-and-forget. Up to 3 facts are stored via save_memory, with ``page_url`` as
    their source.
    """
    if not GROQ_API_KEY or not user_msg:
        return
    page_url = page_url or ""
    prompt = f"""다음 대화에서 사용자에 대해 기억할 만한 핵심 정보를 추출하세요.
사용자 메시지: {user_msg[:1000]}
AI 응답: {(assistant_msg or '')[:500]}
페이지: {page_url[:200]}
JSON 배열로만 응답 (정보 없으면 빈 배열 []):
[
{{"category": "personal|preference|work|skill|location|habit|request", "key": "간단한 키", "value": "값", "importance": 1~10}}
]
규칙:
- 명확한 사실만 추출 (추측 금지)
- 일반적인 질문(날씨, 검색 등)에서는 추출하지 않음
- category: personal(이름/나이), preference(선호), work(직업/회사), skill(기술), location(위치), habit(습관), request(반복요청)
- importance: 이름/직업=9, 선호=7, 일반=5
- 최대 3개까지"""
    try:
        import httpx
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                GROQ_URL,
                headers={"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"},
                json={"model": "openai/gpt-oss-120b", "messages": [{"role": "user", "content": prompt}],
                      "max_completion_tokens": 500, "temperature": 0.1, "stream": False}
            )
        if resp.status_code != 200:
            return
        choices = resp.json().get("choices") or [{}]
        # "content" can be null for some completions; treat that as "no facts".
        text = (choices[0].get("message") or {}).get("content") or ""
        for cat, key, value, imp in _parse_extracted_memories(text):
            await save_memory(email, cat, key, value, source=page_url[:200], importance=imp)
            print(f"[memory] ✅ {email}: {cat}/{key}={value} (imp={imp})")
    except Exception as e:
        print(f"[memory] extract error: {e}")
# ═══ Context builder: memories + task history + conversation history -> system prompt ═══
def _merge_memories(relevant, top, limit=10):
    """Concatenate two memory lists, drop repeated ids (first occurrence wins), cap at ``limit``."""
    seen_ids = set()
    merged = []
    for mem in [*relevant, *top]:
        if mem["id"] not in seen_ids:
            seen_ids.add(mem["id"])
            merged.append(mem)
    return merged[:limit]


def _format_memory_section(memories):
    """Render memories as the long-term-memory prompt section ("" when there are none)."""
    if not memories:
        return ""
    lines = [f"- [{m['category']}] {m['key']}: {m['value']}" for m in memories]
    return "【사용자 기억 정보】\n" + "\n".join(lines)


def _format_task_section(tasks):
    """Render task rows as the recent-tasks prompt section ("" when there are none).

    Timestamps are shown as month/day hour:minute in the server's local time zone.
    """
    if not tasks:
        return ""
    from datetime import datetime  # imported once here instead of on every loop iteration

    lines = []
    for t in tasks:
        ts = datetime.fromtimestamp(t["created_at"]).strftime("%m/%d %H:%M")
        lines.append(f"- [{ts}] {t['task_type']}: {t['task_input'][:80]}")
    return "【최근 작업 이력】\n" + "\n".join(lines)


def _format_conversation_section(convs):
    """Render chat turns as the recent-conversation prompt section ("" when there are none)."""
    if not convs:
        return ""
    lines = []
    for c in convs:
        role_kr = "사용자" if c["role"] == "user" else "AI"
        lines.append(f"{role_kr}: {c['content'][:150]}")
    return "【최근 대화】\n" + "\n".join(lines)


async def build_memory_context(email, current_message):
    """Build the personalization block that is added to the system prompt.

    The block combines three sections, each skipped when empty:
      - up to 10 memories: those matching ``current_message`` first, then the most
        important ones, with duplicates removed;
      - the 5 most recent tasks;
      - the 6 most recent chat turns.
    A closing instruction tells the model to use the context subtly. Returns "" when
    there is nothing to include. ``current_message=None`` is treated as "".
    """
    relevant = await search_memories(email, current_message or "", limit=5)
    top_memories = await get_all_memories(email, limit=10)
    tasks = await get_recent_tasks(email, limit=5)
    convs = await get_recent_conversations(email, limit=6)
    sections = (
        _format_memory_section(_merge_memories(relevant, top_memories)),
        _format_task_section(tasks),
        _format_conversation_section(convs),
    )
    parts = [s for s in sections if s]
    if not parts:
        return ""
    return "\n\n".join(parts) + "\n\n위 정보를 참고하여 개인화된 응답을 해주세요. 기억 정보를 자연스럽게 활용하되, 기억이 있다는 것을 노골적으로 드러내지 마세요."
# ═══ API endpoints ═══
async def memory_list(email: str = ""):
    """Return a user's memories, recent tasks and counts.

    Response: ``{"ok", "memories" (up to 50), "tasks" (up to 20), "conversation_count",
    "memory_count"}``. ``memory_count`` is the user's total number of stored memories
    (a COUNT query). Previously it was the length of the returned page, which could
    never exceed 50. Returns 400 when ``email`` is blank.

    NOTE(review): the email comes straight from the query string and is not
    authenticated here. Confirm that access control is enforced upstream.
    """
    email = (email or "").strip().lower()
    if not email:
        return JSONResponse({"error": "email required"}, status_code=400)
    memories = await get_all_memories(email, limit=50)
    tasks = await get_recent_tasks(email, limit=20)
    conv_count_row = await _db_read_one(
        "SELECT COUNT(*) as cnt FROM conversations WHERE email=?", (email,)
    )
    mem_count_row = await _db_read_one(
        "SELECT COUNT(*) as cnt FROM user_memory WHERE email=?", (email,)
    )
    conv_count = conv_count_row["cnt"] if conv_count_row else 0
    memory_count = mem_count_row["cnt"] if mem_count_row else len(memories)
    return {
        "ok": True,
        "memories": memories,
        "tasks": tasks,
        "conversation_count": conv_count,
        "memory_count": memory_count,
    }
async def memory_delete(request: Request):
    """Delete one memory. JSON body: ``{"email": str, "id": int | str}``.

    Returns 400 in these cases:
      - unparsable or non-object body;
      - missing email or id;
      - id that is not a scalar int/str (booleans are rejected, since ``true`` would
        otherwise match id 1).
    Otherwise returns ``{"ok": True, "deleted": id}``, even when no row matched.

    NOTE(review): the email is taken from the request body and is not authenticated
    here. Confirm that access control is enforced upstream.
    """
    try:
        body = await request.json()
    except Exception:  # a bare except would also swallow asyncio.CancelledError
        return JSONResponse({"error": "invalid json"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "invalid json"}, status_code=400)
    email = str(body.get("email") or "").strip().lower()
    memory_id = body.get("id")
    if not email or not memory_id:
        return JSONResponse({"error": "email and id required"}, status_code=400)
    # Only scalar ids can be bound as an SQL parameter. A list/dict would fail to bind
    # and surface as a 500.
    if isinstance(memory_id, bool) or not isinstance(memory_id, (int, str)):
        return JSONResponse({"error": "invalid id"}, status_code=400)
    await delete_memory(email, memory_id)
    return {"ok": True, "deleted": memory_id}
# (target name, statement) pairs, in the order memory_clear runs them.
_CLEAR_STATEMENTS = (
    ("memory", "DELETE FROM user_memory WHERE email=?"),
    ("conversations", "DELETE FROM conversations WHERE email=?"),
    ("tasks", "DELETE FROM task_history WHERE email=?"),
)


async def memory_clear(request: Request):
    """Wipe a user's stored data. JSON body: ``{"email": str, "target": str}``.

    ``target`` is one of "all" (the default), "memory", "conversations" or "tasks".
    An unknown target now returns 400; previously it returned ok and cleared nothing.
    Each table is cleared with its own statement, so the operation is not atomic
    across tables.

    NOTE(review): the email is taken from the request body and is not authenticated
    here. Confirm that access control is enforced upstream.
    """
    try:
        body = await request.json()
    except Exception:  # a bare except would also swallow asyncio.CancelledError
        return JSONResponse({"error": "invalid json"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "invalid json"}, status_code=400)
    email = str(body.get("email") or "").strip().lower()
    target = body.get("target", "all")
    if not email:
        return JSONResponse({"error": "email required"}, status_code=400)
    valid_targets = ("all",) + tuple(name for name, _ in _CLEAR_STATEMENTS)
    if target not in valid_targets:
        return JSONResponse({"error": "invalid target"}, status_code=400)
    for name, sql in _CLEAR_STATEMENTS:
        if target in ("all", name):
            await _db_write_sync(sql, (email,))
    return {"ok": True, "cleared": target}
async def memory_add(request: Request):
    """Manually add or update a memory. JSON body: ``{"email", "category", "key", "value"}``.

    ``category`` defaults to "general" and is cut to 50 chars. key and value are
    sanitized and cut to 200 / 2000 chars. Non-string field values are coerced with
    ``str``. The memory is stored with source "manual" and importance 8. Returns 400
    when the body is not a JSON object or email/key/value is missing.

    NOTE(review): the email is taken from the request body and is not authenticated
    here. Confirm that access control is enforced upstream.
    """
    try:
        body = await request.json()
    except Exception:  # a bare except would also swallow asyncio.CancelledError
        return JSONResponse({"error": "invalid json"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "invalid json"}, status_code=400)
    email = str(body.get("email") or "").strip().lower()
    category = str(body.get("category") or "general").strip()[:50] or "general"
    key = _sanitize_text(str(body.get("key") or ""))[:200]
    value = _sanitize_text(str(body.get("value") or ""))[:2000]
    if not email or not key or not value:
        return JSONResponse({"error": "email, key, value required"}, status_code=400)
    await save_memory(email, category, key, value, source="manual", importance=8)
    return {"ok": True}
# Strong references to in-flight extraction tasks. The asyncio event loop keeps only weak
# references to tasks, so an unreferenced fire-and-forget task can be garbage-collected
# before it finishes. Each task removes itself from this set when done.
_background_tasks = set()


async def task_log(request: Request):
    """Record a task-history entry for any feature (summarize / translate / search / MARL, ...).

    JSON body: ``{"email", "type", "input", "output", "url", "title"}``. Fields are
    coerced to str and truncated. Returns 400 when the body is not a JSON object or
    email/input is missing. When the input is longer than 20 chars, memory extraction
    is also started in the background without awaiting it.

    NOTE(review): the email is taken from the request body and is not authenticated
    here. Confirm that access control is enforced upstream.
    """
    try:
        body = await request.json()
    except Exception:  # a bare except would also swallow asyncio.CancelledError
        return JSONResponse({"error": "invalid json"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "invalid json"}, status_code=400)
    email = str(body.get("email") or "").strip().lower()
    task_type = str(body.get("type") or "unknown")[:50]
    task_input = _sanitize_text(str(body.get("input") or ""))[:2000]
    task_output = _sanitize_text(str(body.get("output") or ""))[:2000]
    page_url = str(body.get("url") or "")[:500]
    page_title = str(body.get("title") or "")[:200]
    if not email or not task_input:
        return JSONResponse({"error": "email and input required"}, status_code=400)
    await save_task(email, task_type, task_input, task_output, page_url, page_title)
    # Substantial inputs are also mined for long-term memories.
    if len(task_input) > 20:
        task = asyncio.create_task(
            extract_memories_from_conversation(email, task_input, task_output, page_url)
        )
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
    return {"ok": True}