# siteagent / shared.py
# Hugging Face file-page header captured with the source (not part of the module):
#   ginigen-ai's picture
#   Update shared.py
#   446d594 verified
import os
import re as _re
import json
import asyncio as _asyncio
import aiosqlite
# --- Credentials: read once at import time; empty string when the variable is unset ---
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
FAL_KEY = os.getenv("FAL_KEY", "")
HF_TOKEN = os.getenv("HF_TOKEN", "")

# --- Fixed identifiers ---
ADMIN_EMAIL = "arxivgpt@gmail.com"
# Hugging Face dataset repo that _sa_backup_db() uploads the SQLite file to.
SA_BACKUP_REPO = "ginigen-ai/siteagent-db"
# The next two names are not used in the visible part of this module;
# presumably they are consumed by other modules -- confirm before changing.
DATASET_ID = "ginigen-ai/siteagent"
CLIP_BUCKET = "ginigen-ai/siteagent"
CLIP_MAX_BYTES = 1024 ** 3  # 1 GiB

# --- Naver API credentials (environment, empty string when unset) ---
NAVER_CLIENT_ID = os.getenv("NAVER_CLIENT_ID", "")
NAVER_CLIENT_SECRET = os.getenv("NAVER_CLIENT_SECRET", "")

# Groq chat-completions endpoint (OpenAI-compatible path).
GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
def _get_db_path():
    """Pick the SQLite file location.

    Candidates are tried in order: ``/data/siteagent.db`` first, then
    ``./siteagent.db``. A candidate is accepted when its parent directory
    exists (or when it has no directory component at all). If nothing
    qualifies, ``./siteagent.db`` is returned.
    """
    fallback = "./siteagent.db"
    for candidate in ("/data/siteagent.db", fallback):
        parent = os.path.dirname(candidate)
        # An empty parent means a bare filename, which is always acceptable.
        if not parent or os.path.isdir(parent):
            return candidate
    return fallback


# Resolved once at import time; every DB helper below connects to this path.
SA_DB_PATH = _get_db_path()
# Lazily created by _get_lock(); stays None until the first call.
_db_lock = None


def _get_lock():
    """Return the shared ``asyncio.Lock``, creating it on the first call."""
    global _db_lock
    lock = _db_lock
    if lock is None:
        lock = _db_lock = _asyncio.Lock()
    return lock
# Number of writes the worker has committed (drives the periodic backup).
_write_count = 0
# Bounded queue of (sql, params) tuples; None until _init_write_queue() runs.
_write_queue = None
# Strong reference to the worker task. The event loop only keeps weak
# references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_write_worker_task = None


async def _init_write_queue():
    """Create the write queue and start its worker task (first call only).

    Must be awaited from inside a running event loop because it calls
    ``asyncio.create_task``. Later calls are no-ops.
    """
    global _write_queue, _write_worker_task
    if _write_queue is None:
        _write_queue = _asyncio.Queue(maxsize=500)
        # Keep the returned task referenced so the worker cannot disappear
        # mid-execution.
        _write_worker_task = _asyncio.create_task(_db_write_worker())
# Strong references to in-flight backup tasks. The event loop only keeps
# weak references, so a fire-and-forget task could otherwise be
# garbage-collected before it completes.
_db_backup_tasks = set()


async def _db_write_worker():
    """Consume queued writes forever, one SQLite transaction per item.

    Each queue item is a ``(sql, params)`` tuple. String parameters are passed
    through ``_sanitize_text`` before binding. A failure of a single write is
    printed and the loop continues; ``task_done()`` is always called for the
    item. Every 100th successful write schedules ``_sa_backup_db()`` as a
    background task. The loop exits when the task is cancelled.
    """
    global _write_count
    while True:
        try:
            sql, params = await _write_queue.get()
            try:
                if params:
                    params = tuple(
                        _sanitize_text(p) if isinstance(p, str) else p
                        for p in params
                    )
                async with aiosqlite.connect(SA_DB_PATH, timeout=10.0) as db:
                    await db.execute("PRAGMA journal_mode=WAL")
                    await db.execute(sql, params or ())
                    await db.commit()
                _write_count += 1
                if _write_count % 100 == 0:
                    # Hold the task until it finishes, then drop the reference.
                    backup_task = _asyncio.create_task(_sa_backup_db())
                    _db_backup_tasks.add(backup_task)
                    backup_task.add_done_callback(_db_backup_tasks.discard)
            except Exception as e:
                # Best effort: a bad statement must not kill the worker.
                print(f"[db-worker] {e}")
            finally:
                _write_queue.task_done()
        except _asyncio.CancelledError:
            break
        except Exception as e:
            print(f"[db-worker-fatal] {e}")
            # Brief pause so a persistent failure does not spin the loop.
            await _asyncio.sleep(0.1)
def _db_enqueue(sql, params=None):
    """Queue a write for the background worker without blocking.

    Does nothing when the queue has not been initialised yet. When the queue
    is full, the oldest pending write is discarded to make room for the new
    one (that discarded write is lost).

    Args:
        sql: SQL statement to execute.
        params: Optional sequence of bind parameters.
    """
    if _write_queue is None:
        return
    try:
        _write_queue.put_nowait((sql, params))
    except _asyncio.QueueFull:
        try:
            _write_queue.get_nowait()
        except _asyncio.QueueEmpty:
            return
        # The evicted item will never reach the worker, so account for it
        # here; otherwise the queue's unfinished-task counter never returns
        # to zero and Queue.join() would wait forever.
        _write_queue.task_done()
        try:
            _write_queue.put_nowait((sql, params))
        except _asyncio.QueueFull:
            # Still full: drop the new write, matching the best-effort intent.
            pass
async def _db_write(sql, params=None):
    """Awaitable wrapper around ``_db_enqueue``.

    The write is only queued here, not executed; the coroutine returns
    ``None`` as soon as the item has been handed to ``_db_enqueue``.
    """
    _db_enqueue(sql, params=params)
async def _db_write_sync(sql, params=None):
    """Execute and commit one statement immediately, bypassing the write queue.

    Opens its own connection, switches it to WAL journal mode, runs the
    statement and commits.

    Returns:
        The cursor produced by the statement. The connection has already been
        closed by the time the caller receives it.
    """
    bound = params or ()
    async with aiosqlite.connect(SA_DB_PATH, timeout=10.0) as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        result_cursor = await conn.execute(sql, bound)
        await conn.commit()
        return result_cursor
async def _db_read(sql, params=None):
    """Run a query on a fresh connection and return every row.

    Rows are produced with ``aiosqlite.Row`` as the row factory.
    """
    bound = params or ()
    async with aiosqlite.connect(SA_DB_PATH, timeout=10.0) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(sql, bound)
        # Fetch while the connection is still open.
        rows = await cur.fetchall()
    return rows
async def _db_read_one(sql, params=None):
    """Run a query on a fresh connection and return the first row.

    The row is produced with ``aiosqlite.Row`` as the row factory; the result
    is whatever ``fetchone()`` yields for the query.
    """
    bound = params or ()
    async with aiosqlite.connect(SA_DB_PATH, timeout=10.0) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(sql, bound)
        # Fetch while the connection is still open.
        first_row = await cur.fetchone()
    return first_row
async def _sa_backup_db():
    """Upload the SQLite file to the Hugging Face backup dataset repo.

    Skipped when ``HF_TOKEN`` is empty or the database file does not exist.
    The file is uploaded only if ``PRAGMA integrity_check`` reports ``ok``.
    It is uploaded twice: to a timestamped path under ``backup/`` and to
    ``latest/siteagent.db``. Any failure is printed, never raised.

    NOTE(review): there is no ``await`` in this coroutine, so the sqlite check
    and both uploads run synchronously inside whichever task awaits it.
    """
    if not HF_TOKEN or not os.path.exists(SA_DB_PATH):
        return
    try:
        import sqlite3
        from contextlib import closing
        from datetime import datetime, timezone

        # closing() guarantees the connection is released even when the
        # integrity check itself raises.
        with closing(sqlite3.connect(SA_DB_PATH)) as conn:
            result = conn.execute("PRAGMA integrity_check").fetchone()
        if result[0] != "ok":
            print(f"⚠️ DB integrity issue: {result[0]}")
            return

        ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

        from huggingface_hub import HfApi
        api = HfApi(token=HF_TOKEN)
        try:
            api.create_repo(repo_id=SA_BACKUP_REPO, repo_type="dataset", private=True, exist_ok=True)
        except Exception as e:
            # Still best effort (the uploads below are attempted regardless),
            # but no longer silent.
            print(f"⚠️ [SA-backup] create_repo: {e}")
        api.upload_file(path_or_fileobj=SA_DB_PATH, path_in_repo=f"backup/siteagent_{ts}.db", repo_id=SA_BACKUP_REPO, repo_type="dataset")
        api.upload_file(path_or_fileobj=SA_DB_PATH, path_in_repo="latest/siteagent.db", repo_id=SA_BACKUP_REPO, repo_type="dataset")
        print(f"✅ [SA-backup] Hub backup: {ts}")
    except Exception as e:
        print(f"⚠️ [SA-backup] {e}")
# Matches any UTF-16 surrogate code point (U+D800..U+DFFF). These are the only
# code points a Python str can hold that cannot be encoded as UTF-8.
_SURROGATE_RE = _re.compile(r'[\ud800-\udfff]')


def _sanitize_text(text):
    """Remove surrogate code points so the text can be encoded as UTF-8.

    Args:
        text: Any value. Non-string values are returned unchanged.

    Returns:
        ``text`` with every code point in U+D800..U+DFFF deleted, or the
        original object when it is not a ``str``.
    """
    if not isinstance(text, str):
        return text
    # Once surrogates are gone the string always encodes as UTF-8, so no
    # trial encode or fallback path is needed.
    return _SURROGATE_RE.sub('', text)
def _sanitize_messages(messages):
    """Return a copy of ``messages`` with their text passed through ``_sanitize_text``.

    Each message is shallow-copied with ``dict(msg)``. A string ``content`` is
    sanitized directly; a list ``content`` is rebuilt so that every dict item
    holding a ``"text"`` key gets that value sanitized, while other items are
    kept as they are. A falsy ``messages`` value is returned unchanged.
    """
    if not messages:
        return messages

    def _clean_part(part):
        # Only dict parts that carry a "text" key are rewritten.
        if isinstance(part, dict) and "text" in part:
            return {**part, "text": _sanitize_text(part["text"])}
        return part

    sanitized = []
    for original in messages:
        entry = dict(original)
        content = entry.get("content")
        if isinstance(content, str):
            entry["content"] = _sanitize_text(content)
        elif isinstance(content, list):
            entry["content"] = [_clean_part(part) for part in content]
        sanitized.append(entry)
    return sanitized
def _split_table_row(line):
    """Split one markdown table row into a list of stripped cell strings."""
    row = line.strip()
    # Remove exactly one border pipe per side so empty edge cells keep their
    # column position.
    if row.startswith('|'):
        row = row[1:]
    if row.endswith('|'):
        row = row[:-1]
    return [cell.strip() for cell in row.split('|')]


def _flush_table(tbl, out):
    """Append a plain-text rendering of the markdown table ``tbl`` to ``out``.

    Args:
        tbl: Consecutive raw table lines; the first one is the header row.
        out: List of output lines, extended in place.

    With fewer than two lines the input is copied through untouched.
    Otherwise the header becomes one bold line with cells joined by `` · ``,
    a separator row in second position is dropped, and each remaining row
    becomes a bullet whose non-empty cells are joined by `` / `` and prefixed
    with their bold column header (when there is one and it differs from the
    cell text).
    """
    if len(tbl) < 2:
        out.extend(tbl)
        return
    # Surrounding whitespace is removed before the border pipes; otherwise an
    # indented row keeps its leading pipe, gains a phantom first cell and is
    # paired with the wrong headers.
    headers = _split_table_row(tbl[0])
    for i, tr in enumerate(tbl):
        cells = _split_table_row(tr)
        if i == 0:
            out.append('**' + ' · '.join(c for c in cells if c) + '**')
        elif i == 1 and _re.match(r'^[\s:\-|]+$', tr):
            # Alignment/separator row such as |---|:--:|
            continue
        else:
            parts = []
            for j, cell in enumerate(cells):
                if not cell:
                    continue
                if j < len(headers) and headers[j] and headers[j] != cell:
                    parts.append(f'**{headers[j]}**: {cell}')
                else:
                    parts.append(cell)
            out.append('• ' + ' / '.join(parts))
def _strip_md_table(text):
    """Replace markdown pipe tables in ``text`` with the ``_flush_table`` rendering.

    Non-string input, and strings that contain no ``|`` at all, are returned
    unchanged. Otherwise the text is scanned line by line: each run of
    consecutive lines that start and end with a pipe (surrounding whitespace
    allowed) is handed to ``_flush_table``; all other lines are copied through.
    """
    if not (isinstance(text, str) and '|' in text):
        return text
    row_pattern = _re.compile(r'^\s*\|.+\|\s*$')
    rendered = []
    pending = []  # table rows collected since the last non-table line
    for line in text.split('\n'):
        if row_pattern.match(line):
            pending.append(line)
            continue
        if pending:
            _flush_table(pending, rendered)
            pending = []
        rendered.append(line)
    # A table that runs to the end of the text still has to be emitted.
    if pending:
        _flush_table(pending, rendered)
    return '\n'.join(rendered)