Embedding / main.py
summerstars's picture
Update main.py
e9fc73c verified
# enhanced_app.py
import os
import time
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import List, Dict, Any
import numpy as np
import torch
from cachetools import TTLCache
from sentence_transformers import SentenceTransformer
import gradio as gr
# ---------- 設定 ----------
MODEL_NAME = os.environ.get("MODEL_NAME", "summerstars/MARK-Embedding")
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 32))
CACHE_MAXSIZE = int(os.environ.get("CACHE_MAXSIZE", 2000))
CACHE_TTL = int(os.environ.get("CACHE_TTL", 300)) # seconds
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", max(4, (os.cpu_count() or 1) * 2)))
GRADIO_QUEUE_SIZE = int(os.environ.get("GRADIO_QUEUE_SIZE", 500))
GRADIO_CONCURRENCY = int(os.environ.get("GRADIO_CONCURRENCY", 20))
_LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
# ---------- ロガー ----------
logging.basicConfig(level=_LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("enhanced_app")
# ---------- グローバル ----------
model: SentenceTransformer = None
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
embedding_cache = TTLCache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
cache_lock = threading.RLock()
# Optional: Faiss を使った高速近似検索(あれば有効化できます)
HAS_FAISS = False
try:
import faiss # type: ignore
HAS_FAISS = True
logger.info("faiss detected: ANN indexing available")
except Exception:
logger.info("faiss not available: ANN indexing disabled")
# ---------- ユーティリティ ----------
def _normalize_rows(arr: np.ndarray, eps: float = 1e-12) -> np.ndarray:
norms = np.linalg.norm(arr, axis=1, keepdims=True)
norms[norms == 0] = eps
return arr / norms
def _encode_texts_sync(texts: List[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
"""同期版のエンコード(スレッドプールで呼ぶ) -> 正規化済み numpy.ndarray"""
global model
if model is None:
raise RuntimeError("Model not loaded")
# SentenceTransformer.encode with convert_to_tensor=False returns np.ndarray
with torch.no_grad():
emb = model.encode(texts, convert_to_tensor=False, batch_size=batch_size, show_progress_bar=False)
arr = np.asarray(emb, dtype=np.float32)
arr = _normalize_rows(arr)
return arr
async def encode_texts_async(texts: List[str]) -> np.ndarray:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(executor, _encode_texts_sync, texts, BATCH_SIZE)
# ---------- モデルロード ----------
async def load_model() -> bool:
"""非同期でモデルをロード(起動時に一度だけ呼ぶ)"""
global model
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Loading model '{MODEL_NAME}' on device: {device} ...")
try:
# SentenceTransformer は device 引数で内部的にデバイスを設定します
model = SentenceTransformer(MODEL_NAME, device=device, trust_remote_code=True)
model.eval()
# 注意: GPU メモリを節約したい場合は外部で 8-bit/bitsandbytes を検討
logger.info("Model loaded successfully.")
return True
except Exception as e:
logger.exception("Failed to load model")
return False
# ---------- キャッシュ付き埋め込み取得 ----------
async def get_embeddings(texts: str) -> Dict[str, List[float]]:
"""
改行区切りのテキスト群を受け取り、キャッシュを活用して埋め込みを返す。
返却形式: { "文1": [f32,...], "文2": [...] }
"""
if not texts or not texts.strip():
return {}
sentences = [s.strip() for s in texts.split("\n") if s.strip()]
if not sentences:
return {}
uncached = []
results: Dict[str, np.ndarray] = {}
# キャッシュチェック(スレッド安全)
with cache_lock:
for s in sentences:
v = embedding_cache.get(s)
if v is not None:
results[s] = v # numpy array
else:
uncached.append(s)
# 未キャッシュの文を一括で計算
if uncached:
logger.debug(f"Encoding {len(uncached)} uncached sentences (batch_size={BATCH_SIZE})")
arr = await encode_texts_async(uncached) # 正規化済み np.ndarray (N, D)
with cache_lock:
for s, vec in zip(uncached, arr):
embedding_cache[s] = vec # 保持は numpy array
results[s] = vec
# JSON シリアライズのため list に変換して返却
return {s: results[s].tolist() for s in sentences}
# ---------- 単一文の埋め込み取得(内部ユーティリティ) ----------
async def _get_single_embedding(text: str) -> np.ndarray:
if not text:
raise ValueError("empty text")
with cache_lock:
v = embedding_cache.get(text)
if v is not None:
return v
arr = await encode_texts_async([text])
vec = arr[0]
with cache_lock:
embedding_cache[text] = vec
return vec
# ---------- 類似度計算 ----------
async def calculate_similarity(text1: str, text2: str) -> float:
"""
2つの文章のコサイン類似度を返す (float、-1.0〜1.0)
非同期で動作し、内部でキャッシュを活用します。
"""
start = time.time()
if not text1 or not text2:
return 0.0
try:
emb1 = await _get_single_embedding(text1)
emb2 = await _get_single_embedding(text2)
# どちらも正規化済み -> dot product が cosine similarity
sim = float(np.dot(emb1, emb2))
elapsed = (time.time() - start) * 1000
logger.debug(f"Similarity computed in {elapsed:.1f}ms -> {sim:.6f}")
return round(sim, 6)
except Exception as e:
logger.exception("Error in calculate_similarity")
return 0.0
# ---------- Gradio アプリ ----------
async def main():
ok = await load_model()
if not ok:
logger.error("Model load failed - exiting")
return
with gr.Blocks(theme=gr.themes.Default()) as demo:
gr.Markdown(f"# 高速テキスト類似度 API — `{MODEL_NAME}`\n高速化・キャッシュ・非同期実行を適用しています。")
with gr.Tab("コンテキスト理解"):
gr.Markdown("## 2つの文章を入力して類似度を計算します")
with gr.Row():
text_input1 = gr.Textbox(label="文章1", lines=3, placeholder="例: 今日の天気は晴れです。")
text_input2 = gr.Textbox(label="文章2", lines=3, placeholder="例: 今日は良い天気ですね。")
calculate_button = gr.Button("類似度を計算", variant="primary")
similarity_output = gr.Number(label="コサイン類似度スコア")
calculate_button.click(fn=calculate_similarity, inputs=[text_input1, text_input2], outputs=similarity_output)
with gr.Tab("埋め込みベクトル生成"):
gr.Markdown("## テキスト(改行区切り)を入力して埋め込みベクトルを生成します")
texts_input = gr.Textbox(label="テキスト入力 (1行に1つの文章)", lines=6, placeholder="犬が公園を走っている。\n猫が窓際で日向ぼっこをしている。")
generate_button = gr.Button("ベクトルを生成", variant="primary")
embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)
gr.Markdown("### ヘルスチェック")
gr.Textbox(value="OK", label="サービス状態 (静的表示)")
# API: calculate_similarity を HTTP API として公開(Gradio のバージョン依存)
try:
demo.api(calculate_similarity, api_name="calculate_similarity")
except Exception:
# 古い/新しい Gradio で API 作成方法が違うことがあるため安全にフォールバック
logger.debug("demo.api not supported in this environment; skipping automatic API registration")
demo.queue(max_size=GRADIO_QUEUE_SIZE, default_concurrency_limit=GRADIO_CONCURRENCY)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True, max_threads=MAX_WORKERS)
if __name__ == "__main__":
asyncio.run(main())