# NOTE(review): the three lines that were here ("Spaces: / Sleeping / Sleeping")
# were Hugging Face Spaces status-page residue from a web scrape, not code.
# enhanced_app.py
import asyncio
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import gradio as gr
import numpy as np
import torch
from cachetools import TTLCache
from sentence_transformers import SentenceTransformer
# ---------- Configuration (all values overridable via environment variables) ----------
MODEL_NAME = os.environ.get("MODEL_NAME", "summerstars/MARK-Embedding")
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 32))
CACHE_MAXSIZE = int(os.environ.get("CACHE_MAXSIZE", 2000))
CACHE_TTL = int(os.environ.get("CACHE_TTL", 300))  # seconds
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", max(4, (os.cpu_count() or 1) * 2)))
GRADIO_QUEUE_SIZE = int(os.environ.get("GRADIO_QUEUE_SIZE", 500))
GRADIO_CONCURRENCY = int(os.environ.get("GRADIO_CONCURRENCY", 20))
_LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")

# ---------- Logger ----------
logging.basicConfig(level=_LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("enhanced_app")

# ---------- Globals ----------
# Fix: the model starts out as None, so the annotation must be Optional —
# annotating it as a plain SentenceTransformer was incorrect for the None default.
# Populated exactly once at startup by load_model().
model: Optional[SentenceTransformer] = None
# Shared pool for blocking encode work, so the asyncio event loop never blocks.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# sentence -> normalized embedding (np.ndarray); entries expire after CACHE_TTL seconds.
embedding_cache = TTLCache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
# RLock guards the cache: TTLCache mutates internal state even on reads (expiry).
cache_lock = threading.RLock()
# Optional: fast approximate-nearest-neighbor search via Faiss, enabled only
# when the package (and its native library) imports cleanly.
HAS_FAISS = False
try:
    import faiss  # type: ignore
except Exception:
    # Broad on purpose: a missing/broken native lib can raise more than ImportError.
    logger.info("faiss not available: ANN indexing disabled")
else:
    HAS_FAISS = True
    logger.info("faiss detected: ANN indexing available")
| # ---------- ユーティリティ ---------- | |
| def _normalize_rows(arr: np.ndarray, eps: float = 1e-12) -> np.ndarray: | |
| norms = np.linalg.norm(arr, axis=1, keepdims=True) | |
| norms[norms == 0] = eps | |
| return arr / norms | |
def _encode_texts_sync(texts: List[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
    """Blocking encode (run this on the thread pool): returns row-normalized float32 embeddings."""
    global model
    if model is None:
        raise RuntimeError("Model not loaded")
    # encode() with convert_to_tensor=False yields a numpy-compatible result.
    with torch.no_grad():
        raw = model.encode(
            texts,
            convert_to_tensor=False,
            batch_size=batch_size,
            show_progress_bar=False,
        )
    return _normalize_rows(np.asarray(raw, dtype=np.float32))
async def encode_texts_async(texts: List[str]) -> np.ndarray:
    """Run the blocking encoder on the shared thread pool without stalling the event loop."""
    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(
        executor, _encode_texts_sync, texts, BATCH_SIZE
    )
# ---------- Model loading ----------
async def load_model() -> bool:
    """Load the SentenceTransformer once at startup.

    Returns True on success, False on failure (the error is logged with traceback).

    Fix: the original constructed the model directly inside this coroutine, which
    blocks the asyncio event loop for the whole (potentially long) download/load.
    The blocking work now runs on the shared thread-pool executor instead.
    """
    global model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Loading model '{MODEL_NAME}' on device: {device} ...")

    def _blocking_load() -> SentenceTransformer:
        # SentenceTransformer places itself on `device` via the constructor arg.
        m = SentenceTransformer(MODEL_NAME, device=device, trust_remote_code=True)
        m.eval()
        return m

    try:
        loop = asyncio.get_running_loop()
        model = await loop.run_in_executor(executor, _blocking_load)
        # NOTE: consider 8-bit/bitsandbytes externally if GPU memory is tight.
        logger.info("Model loaded successfully.")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False
# ---------- Cached embedding retrieval ----------
async def get_embeddings(texts: str) -> Dict[str, List[float]]:
    """Embed newline-separated sentences, reusing cached vectors where possible.

    Returns a JSON-serializable mapping: { sentence: [f32, ...], ... }.
    Empty/whitespace-only input yields an empty dict.
    """
    if not texts or not texts.strip():
        return {}
    sentences = [line.strip() for line in texts.split("\n") if line.strip()]
    if not sentences:
        return {}

    found: Dict[str, np.ndarray] = {}
    missing: List[str] = []
    # Cache lookup under the lock (TTLCache is not thread-safe by itself).
    with cache_lock:
        for sentence in sentences:
            cached = embedding_cache.get(sentence)
            if cached is None:
                missing.append(sentence)
            else:
                found[sentence] = cached  # numpy array

    # Batch-encode everything that missed the cache.
    if missing:
        logger.debug(f"Encoding {len(missing)} uncached sentences (batch_size={BATCH_SIZE})")
        fresh = await encode_texts_async(missing)  # normalized np.ndarray (N, D)
        with cache_lock:
            for sentence, vector in zip(missing, fresh):
                embedding_cache[sentence] = vector  # cache keeps numpy arrays
                found[sentence] = vector

    # Convert to plain lists so the result serializes as JSON.
    return {sentence: found[sentence].tolist() for sentence in sentences}
# ---------- Single-sentence embedding (internal utility) ----------
async def _get_single_embedding(text: str) -> np.ndarray:
    """Return the normalized embedding of one string, consulting the TTL cache first.

    Raises ValueError on empty input.
    """
    if not text:
        raise ValueError("empty text")
    with cache_lock:
        hit = embedding_cache.get(text)
    if hit is not None:
        return hit
    vector = (await encode_texts_async([text]))[0]
    with cache_lock:
        embedding_cache[text] = vector
    return vector
# ---------- Similarity computation ----------
async def calculate_similarity(text1: str, text2: str) -> float:
    """Cosine similarity of two texts as a float in [-1.0, 1.0].

    Runs asynchronously and reuses the embedding cache. Returns 0.0 for empty
    input or on any internal failure (the failure is logged with traceback).
    """
    started = time.time()
    if not text1 or not text2:
        return 0.0
    try:
        vec_a = await _get_single_embedding(text1)
        vec_b = await _get_single_embedding(text2)
        # Both vectors are unit-normalized, so the dot product IS the cosine similarity.
        score = float(np.dot(vec_a, vec_b))
        logger.debug(f"Similarity computed in {(time.time() - started) * 1000:.1f}ms -> {score:.6f}")
        return round(score, 6)
    except Exception:
        logger.exception("Error in calculate_similarity")
        return 0.0
# ---------- Gradio app ----------
async def main():
    """App entry point: load the model once, build the two-tab Gradio UI, and serve it.

    Exits early (after logging) if the model fails to load. Queueing and
    concurrency limits come from the GRADIO_* environment-driven constants.
    """
    ok = await load_model()
    if not ok:
        logger.error("Model load failed - exiting")
        return
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        gr.Markdown(f"# 高速テキスト類似度 API — `{MODEL_NAME}`\n高速化・キャッシュ・非同期実行を適用しています。")
        # Tab 1: pairwise cosine similarity between two free-text inputs.
        with gr.Tab("コンテキスト理解"):
            gr.Markdown("## 2つの文章を入力して類似度を計算します")
            with gr.Row():
                text_input1 = gr.Textbox(label="文章1", lines=3, placeholder="例: 今日の天気は晴れです。")
                text_input2 = gr.Textbox(label="文章2", lines=3, placeholder="例: 今日は良い天気ですね。")
            calculate_button = gr.Button("類似度を計算", variant="primary")
            similarity_output = gr.Number(label="コサイン類似度スコア")
            # Gradio awaits async callbacks natively, so the coroutine function is passed directly.
            calculate_button.click(fn=calculate_similarity, inputs=[text_input1, text_input2], outputs=similarity_output)
        # Tab 2: batch embedding generation from newline-separated input.
        with gr.Tab("埋め込みベクトル生成"):
            gr.Markdown("## テキスト(改行区切り)を入力して埋め込みベクトルを生成します")
            texts_input = gr.Textbox(label="テキスト入力 (1行に1つの文章)", lines=6, placeholder="犬が公園を走っている。\n猫が窓際で日向ぼっこをしている。")
            generate_button = gr.Button("ベクトルを生成", variant="primary")
            embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
            generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)
        gr.Markdown("### ヘルスチェック")
        gr.Textbox(value="OK", label="サービス状態 (静的表示)")
        # Expose calculate_similarity over HTTP (availability depends on the Gradio version).
        try:
            demo.api(calculate_similarity, api_name="calculate_similarity")
        except Exception:
            # API registration differs between old/new Gradio releases; skip safely if unsupported.
            logger.debug("demo.api not supported in this environment; skipping automatic API registration")
    demo.queue(max_size=GRADIO_QUEUE_SIZE, default_concurrency_limit=GRADIO_CONCURRENCY)
    # NOTE(review): 0.0.0.0 binding is intentional for containerized/Spaces deployment.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True, max_threads=MAX_WORKERS)


if __name__ == "__main__":
    asyncio.run(main())