import json
import os
from contextlib import contextmanager

import boto3
import dotenv
import psycopg2
import streamlit as st
from openai import OpenAI
from pgvector.psycopg2 import register_vector

from utils import json_safe

dotenv.load_dotenv()

# Embedding client (Nebius, OpenAI-compatible API).
_openai_client = OpenAI(
    base_url="https://api.tokenfactory.nebius.com/v1/",
    api_key=os.environ.get("NEBIUS_API_KEY"),
)

# RDS settings: credentials come from Secrets Manager, hosts from the environment.
_region = os.environ.get("AWS_REGION")
_secret_arn = os.environ.get("RDS_SECRET_ARN")
_dbname = os.environ.get("RDS_DB_NAME")

_sm_client = boto3.client("secretsmanager", region_name=_region)
_secret_value = _sm_client.get_secret_value(SecretId=_secret_arn)
_secret_dict = json.loads(_secret_value["SecretString"])

_reader_host = os.environ.get("RDS_READER_HOST")
_writer_host = os.environ.get("RDS_WRITER_HOST")


def embed_query(query: str):
    """Embed a search query with Qwen3-Embedding-8B and return the raw vector."""
    response = _openai_client.embeddings.create(
        model="Qwen/Qwen3-Embedding-8B",
        input=query,
    )
    return response.data[0].embedding


@st.cache_data(ttl=60 * 60 * 24 * 7)
def cached_embed(query):
    return embed_query(query)


@contextmanager
def get_rds_conn(host: str):
    conn = psycopg2.connect(
        host=host,
        port=int(_secret_dict.get("port", 5432)),
        dbname=_dbname or _secret_dict.get("dbname"),
        user=_secret_dict["username"],
        password=_secret_dict["password"],
        sslmode="require",
    )
    try:
        # The inner `with conn` commits on success and rolls back on error;
        # the outer try/finally ensures the connection is actually closed
        # (psycopg2's `with connection` does not close it).
        with conn:
            register_vector(conn)
            yield conn
    finally:
        conn.close()


@contextmanager
def reader_conn():
    with get_rds_conn(_reader_host) as conn:
        yield conn


@contextmanager
def writer_conn():
    with get_rds_conn(_writer_host) as conn:
        yield conn


@st.cache_data(ttl=60 * 60 * 24 * 7)
def load_sources():
    with reader_conn() as conn, conn.cursor() as cur:
        cur.execute("""
            SELECT array_agg(DISTINCT source ORDER BY source)
            FROM theorem_search_qwen8b;
        """)
        return cur.fetchone()[0] or []


@st.cache_data(ttl=60 * 60 * 24 * 7)
def load_source_caps():
    with reader_conn() as conn, conn.cursor() as cur:
        cur.execute("""
            SELECT jsonb_object_agg(
                source,
                jsonb_build_object('has_metadata', has_metadata)
            )
            FROM (
                SELECT source, bool_or(has_metadata) AS has_metadata
                FROM theorem_search_qwen8b
                GROUP BY source
            ) s;
        """)
        return cur.fetchone()[0] or {}


@st.cache_data(ttl=60 * 60 * 24 * 7)
def load_authors():
    with reader_conn() as conn, conn.cursor() as cur:
        cur.execute("""
            SELECT jsonb_object_agg(source, authors)
            FROM (
                SELECT source,
                       array_agg(DISTINCT author ORDER BY author) AS authors
                FROM (
                    SELECT source, unnest(authors) AS author
                    FROM theorem_search_qwen8b
                    WHERE authors IS NOT NULL
                ) t
                GROUP BY source
            ) s;
        """)
        return cur.fetchone()[0] or {}


@st.cache_data(ttl=60 * 60 * 24 * 7)
def load_tags():
    with reader_conn() as conn, conn.cursor() as cur:
        cur.execute("""
            SELECT jsonb_object_agg(source, tags)
            FROM (
                SELECT source,
                       array_agg(DISTINCT primary_category ORDER BY primary_category) AS tags
                FROM theorem_search_qwen8b
                WHERE primary_category IS NOT NULL
                GROUP BY source
            ) s;
        """)
        return cur.fetchone()[0] or {}


@st.cache_data(ttl=60 * 60 * 24 * 7)
def load_theorem_count():
    with reader_conn() as conn, conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM theorem_search_qwen8b;")
        return cur.fetchone()[0]


def row_to_dict(cursor, row):
    return {desc[0]: row[i] for i, desc in enumerate(cursor.description)}


def insert_feedback(payload: dict):
    sql = """
        INSERT INTO feedback (
            feedback, query, url, theorem_name, authors, types, tags, sources,
            paper_filter, year_range, citation_range, citation_weight,
            include_unknown_citations, top_k
        )
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
    """
    with writer_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, (
            payload["feedback"],
            payload["query"],
            payload["url"],
            payload["theorem_name"],
            payload["authors"],
            payload["types"],
            payload["tags"],
            payload["sources"],
            payload["paper_filter"],
            payload["year_range"],
            payload["citation_range"],
            payload["citation_weight"],
            payload["include_unknown_citations"],
            payload["top_k"],
        ))


def insert_query(query: str, filters: dict):
    with writer_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO public.queries (query, sources, filters)
            VALUES (%s, %s, %s);
            """,
            (
                query,
                filters["sources"],
                json.dumps(json_safe(filters)),
            ),
        )


def fetch_candidate_ids(
    query_vec,
    citation_weight,
    top_k,
    where_sql,
    where_params,
):
    """Per-source ANN search; returns the top_k (slogan_id, similarity, score) rows."""
    with reader_conn() as conn, conn.cursor() as cur:
        selected_sources = where_params.get("sources", [])
        if not selected_sources:
            return []

        per_source_multiplier = 3
        ef_search = max(80, top_k * 4)
        cur.execute("SET LOCAL hnsw.ef_search = %s;", (ef_search,))
        cur.execute("SET LOCAL hnsw.iterative_scan = 'relaxed_order';")

        all_rows = []
        for source in selected_sources:
            # Coarse ANN pass over binary-quantized embeddings (Hamming distance),
            # then an exact cosine rerank of the candidates plus a log-citation bonus.
            sql = """
                WITH ann AS (
                    SELECT slogan_id, citations, embedding
                    FROM theorem_search_qwen8b
                    WHERE source = %(source)s
                    ORDER BY (binary_quantize(embedding)::bit(4096))
                        <~> binary_quantize(%(query_vec_ann)s::vector(4096))::bit(4096)
                    LIMIT %(per_source_limit)s
                )
                SELECT
                    slogan_id,
                    (1.0 - (embedding <=> %(query_vec_rerank)s::vector(4096))) AS similarity,
                    (1.0 - (embedding <=> %(query_vec_rerank)s::vector(4096)))
                        + %(citation_weight)s * CASE
                            WHEN citations > 0 THEN ln(citations::float)
                            ELSE 0
                        END AS score
                FROM ann;
            """
            params = {
                "source": source,
                "query_vec_ann": query_vec,
                "query_vec_rerank": query_vec,
                "citation_weight": citation_weight,
                "per_source_limit": top_k * per_source_multiplier,
            }
            cur.execute(sql, params)
            all_rows.extend(cur.fetchall())

        if not all_rows:
            return []

        all_rows.sort(key=lambda x: x[2], reverse=True)
        return all_rows[:top_k]


def fetch_full_rows(slogan_rows):
    if not slogan_rows:
        return []

    slogan_ids = [r[0] for r in slogan_rows]
    score_map = {r[0]: (r[1], r[2]) for r in slogan_rows}

    with reader_conn() as conn, conn.cursor() as cur:
        sql = """
            SELECT slogan_id, theorem_id, paper_id, theorem_name, theorem_body,
                   theorem_slogan, theorem_type, title, authors, link, year,
                   journal_published, primary_category, categories, citations,
                   source, has_metadata
            FROM theorem_search_qwen8b
            WHERE slogan_id = ANY(%(ids)s)
            ORDER BY array_position(%(ids)s, slogan_id);
        """
        cur.execute(sql, {"ids": slogan_ids})
        rows = cur.fetchall()
        return [
            {
                **row_to_dict(cur, row),
                "similarity": score_map[row[0]][0],
                "score": score_map[row[0]][1],
            }
            for row in rows
        ]


def fetch_results(
    query_vec,
    citation_weight,
    top_k,
    where_sql,
    where_params,
):
    candidates = fetch_candidate_ids(
        query_vec=query_vec,
        citation_weight=citation_weight,
        top_k=top_k,
        where_sql=where_sql,
        where_params=where_params,
    )
    return fetch_full_rows(candidates)