File size: 8,471 Bytes
e9fc73c
f930370
e9fc73c
 
6ddf006
e9fc73c
 
 
 
 
 
6ddf006
e9fc73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7c8206
e9fc73c
 
 
 
 
 
 
 
629d674
e9fc73c
 
 
 
 
6ddf006
e9fc73c
 
6ddf006
e9fc73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ddf006
e9fc73c
 
 
 
 
 
6ddf006
e9fc73c
6ddf006
 
e9fc73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ddf006
e9fc73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ddf006
e9fc73c
 
 
 
 
6ddf006
 
e9fc73c
 
 
 
 
 
 
 
 
 
 
 
 
6ddf006
e9fc73c
 
 
6ddf006
 
 
e9fc73c
 
586ee73
e9fc73c
6ddf006
 
 
 
e9fc73c
 
 
6ddf006
 
e9fc73c
6ddf006
 
 
 
e9fc73c
 
 
 
 
 
 
 
 
 
 
 
6ddf006
b7c8206
e9fc73c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# enhanced_app.py
import os
import time
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import List, Dict, Any

import numpy as np
import torch
from cachetools import TTLCache
from sentence_transformers import SentenceTransformer
import gradio as gr

# ---------- 設定 ----------
MODEL_NAME = os.environ.get("MODEL_NAME", "summerstars/MARK-Embedding")
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 32))
CACHE_MAXSIZE = int(os.environ.get("CACHE_MAXSIZE", 2000))
CACHE_TTL = int(os.environ.get("CACHE_TTL", 300))  # seconds
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", max(4, (os.cpu_count() or 1) * 2)))
GRADIO_QUEUE_SIZE = int(os.environ.get("GRADIO_QUEUE_SIZE", 500))
GRADIO_CONCURRENCY = int(os.environ.get("GRADIO_CONCURRENCY", 20))
_LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")

# ---------- ロガー ----------
logging.basicConfig(level=_LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("enhanced_app")

# ---------- グローバル ----------
model: SentenceTransformer = None
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
embedding_cache = TTLCache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
cache_lock = threading.RLock()

# Optional: Faiss を使った高速近似検索(あれば有効化できます)
HAS_FAISS = False
try:
    import faiss  # type: ignore
    HAS_FAISS = True
    logger.info("faiss detected: ANN indexing available")
except Exception:
    logger.info("faiss not available: ANN indexing disabled")

# ---------- ユーティリティ ----------
def _normalize_rows(arr: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    norms[norms == 0] = eps
    return arr / norms

def _encode_texts_sync(texts: List[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
    """同期版のエンコード(スレッドプールで呼ぶ) -> 正規化済み numpy.ndarray"""
    global model
    if model is None:
        raise RuntimeError("Model not loaded")
    # SentenceTransformer.encode with convert_to_tensor=False returns np.ndarray
    with torch.no_grad():
        emb = model.encode(texts, convert_to_tensor=False, batch_size=batch_size, show_progress_bar=False)
    arr = np.asarray(emb, dtype=np.float32)
    arr = _normalize_rows(arr)
    return arr

async def encode_texts_async(texts: List[str]) -> np.ndarray:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, _encode_texts_sync, texts, BATCH_SIZE)

# ---------- モデルロード ----------
async def load_model() -> bool:
    """非同期でモデルをロード(起動時に一度だけ呼ぶ)"""
    global model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Loading model '{MODEL_NAME}' on device: {device} ...")
    try:
        # SentenceTransformer は device 引数で内部的にデバイスを設定します
        model = SentenceTransformer(MODEL_NAME, device=device, trust_remote_code=True)
        model.eval()
        # 注意: GPU メモリを節約したい場合は外部で 8-bit/bitsandbytes を検討
        logger.info("Model loaded successfully.")
        return True
    except Exception as e:
        logger.exception("Failed to load model")
        return False

# ---------- キャッシュ付き埋め込み取得 ----------
async def get_embeddings(texts: str) -> Dict[str, List[float]]:
    """
    改行区切りのテキスト群を受け取り、キャッシュを活用して埋め込みを返す。
    返却形式: { "文1": [f32,...], "文2": [...] }
    """
    if not texts or not texts.strip():
        return {}
    sentences = [s.strip() for s in texts.split("\n") if s.strip()]
    if not sentences:
        return {}

    uncached = []
    results: Dict[str, np.ndarray] = {}

    # キャッシュチェック(スレッド安全)
    with cache_lock:
        for s in sentences:
            v = embedding_cache.get(s)
            if v is not None:
                results[s] = v  # numpy array
            else:
                uncached.append(s)

    # 未キャッシュの文を一括で計算
    if uncached:
        logger.debug(f"Encoding {len(uncached)} uncached sentences (batch_size={BATCH_SIZE})")
        arr = await encode_texts_async(uncached)  # 正規化済み np.ndarray (N, D)
        with cache_lock:
            for s, vec in zip(uncached, arr):
                embedding_cache[s] = vec  # 保持は numpy array
                results[s] = vec

    # JSON シリアライズのため list に変換して返却
    return {s: results[s].tolist() for s in sentences}

# ---------- 単一文の埋め込み取得(内部ユーティリティ) ----------
async def _get_single_embedding(text: str) -> np.ndarray:
    if not text:
        raise ValueError("empty text")
    with cache_lock:
        v = embedding_cache.get(text)
        if v is not None:
            return v
    arr = await encode_texts_async([text])
    vec = arr[0]
    with cache_lock:
        embedding_cache[text] = vec
    return vec

# ---------- 類似度計算 ----------
async def calculate_similarity(text1: str, text2: str) -> float:
    """
    2つの文章のコサイン類似度を返す (float、-1.0〜1.0)
    非同期で動作し、内部でキャッシュを活用します。
    """
    start = time.time()
    if not text1 or not text2:
        return 0.0
    try:
        emb1 = await _get_single_embedding(text1)
        emb2 = await _get_single_embedding(text2)
        # どちらも正規化済み -> dot product が cosine similarity
        sim = float(np.dot(emb1, emb2))
        elapsed = (time.time() - start) * 1000
        logger.debug(f"Similarity computed in {elapsed:.1f}ms -> {sim:.6f}")
        return round(sim, 6)
    except Exception as e:
        logger.exception("Error in calculate_similarity")
        return 0.0

# ---------- Gradio アプリ ----------
async def main():
    ok = await load_model()
    if not ok:
        logger.error("Model load failed - exiting")
        return

    with gr.Blocks(theme=gr.themes.Default()) as demo:
        gr.Markdown(f"# 高速テキスト類似度 API — `{MODEL_NAME}`\n高速化・キャッシュ・非同期実行を適用しています。")

        with gr.Tab("コンテキスト理解"):
            gr.Markdown("## 2つの文章を入力して類似度を計算します")
            with gr.Row():
                text_input1 = gr.Textbox(label="文章1", lines=3, placeholder="例: 今日の天気は晴れです。")
                text_input2 = gr.Textbox(label="文章2", lines=3, placeholder="例: 今日は良い天気ですね。")
            calculate_button = gr.Button("類似度を計算", variant="primary")
            similarity_output = gr.Number(label="コサイン類似度スコア")
            calculate_button.click(fn=calculate_similarity, inputs=[text_input1, text_input2], outputs=similarity_output)

        with gr.Tab("埋め込みベクトル生成"):
            gr.Markdown("## テキスト(改行区切り)を入力して埋め込みベクトルを生成します")
            texts_input = gr.Textbox(label="テキスト入力 (1行に1つの文章)", lines=6, placeholder="犬が公園を走っている。\n猫が窓際で日向ぼっこをしている。")
            generate_button = gr.Button("ベクトルを生成", variant="primary")
            embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
            generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)

        gr.Markdown("### ヘルスチェック")
        gr.Textbox(value="OK", label="サービス状態 (静的表示)")

        # API: calculate_similarity を HTTP API として公開(Gradio のバージョン依存)
        try:
            demo.api(calculate_similarity, api_name="calculate_similarity")
        except Exception:
            # 古い/新しい Gradio で API 作成方法が違うことがあるため安全にフォールバック
            logger.debug("demo.api not supported in this environment; skipping automatic API registration")

    demo.queue(max_size=GRADIO_QUEUE_SIZE, default_concurrency_limit=GRADIO_CONCURRENCY)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True, max_threads=MAX_WORKERS)

if __name__ == "__main__":
    asyncio.run(main())