summerstars commited on
Commit
e9fc73c
·
verified ·
1 Parent(s): 4e9ff49

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +163 -100
main.py CHANGED
@@ -1,132 +1,195 @@
1
- import gradio as gr
2
- from sentence_transformers import SentenceTransformer, util
3
- import torch
4
- import csv
5
  import os
 
 
6
  import asyncio
7
- from typing import Tuple, Dict, Any
8
- import functools
 
 
 
 
9
  from cachetools import TTLCache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # --- 1. 設定項目 ---
12
- # 読み込むCSVファイル名 (類似度計算用にオプションで使用可能だが、今回は不要のためコメントアウト)
13
- # CSV_FILE_PATH = "subset.csv"
14
- # 埋め込みモデル (日本語対応)
15
- MODEL_NAME = "summerstars/MARK-Embedding"
 
 
 
16
 
17
- # --- 2. グローバル変数の定義 ---
18
- model = None
19
- # キャッシュ: 最近のテキスト埋め込みをメモリに保存 (TTL: 300秒)
20
- embedding_cache = TTLCache(maxsize=1000, ttl=300)
 
21
 
22
- # --- 3. モデルの非同期ロード ---
23
- async def load_model():
24
- """モデルを非同期でロードする関数"""
25
  global model
26
- print(f"モデル '{MODEL_NAME}' をロードしています...")
27
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  try:
29
- # モデルの最適化: FP16精度でロード (GPU使用時のみ)
30
- if device == 'cuda':
31
- model = SentenceTransformer(MODEL_NAME, device=device, trust_remote_code=True)
32
- model.half() # FP16量化でメモリ効率と速度向上
33
- print("モデルをFP16精度で最適化しました。")
34
- else:
35
- model = SentenceTransformer(MODEL_NAME, device=device)
36
- print(f"モデルのロードが完了しました。デバイス: {device}")
37
  except Exception as e:
38
- print(f"モデルのロード中にエラーが発生しました: {e}")
39
  return False
40
- return True
41
 
42
- # --- 4. APIのコア機能となる非同期関数を定義 ---
43
- @functools.lru_cache(maxsize=128)
44
- def get_cached_embedding(text: str) -> torch.Tensor:
45
- """テキス��の埋め込みをキャッシュ付きで取得 (LRUキャッシュで高速化)"""
46
- return model.encode(text, convert_to_tensor=True, device=model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  async def calculate_similarity(text1: str, text2: str) -> float:
49
- """2つのテキストの類似度を非同期で計算する関数 (キャッシュ活用)"""
 
 
 
 
50
  if not text1 or not text2:
51
  return 0.0
52
- # キャッシュから埋め込みを取得 (なければ計算)
53
- emb1 = get_cached_embedding(text1)
54
- emb2 = get_cached_embedding(text2)
55
- cosine_scores = util.cos_sim(emb1, emb2)
56
- return round(cosine_scores.item(), 4)
57
-
58
- async def get_embeddings(texts: str) -> Dict[str, list]:
59
- """テキストの埋め込みベクトルを非同期で生成する関数 (バッチ処理とキャッシュ活用)"""
60
- if not texts.strip():
61
- return {}
62
- sentences = [s.strip() for s in texts.strip().split('\n') if s.strip()]
63
- if not sentences:
64
- return {}
65
-
66
- # キャッシュチェックと未キャッシュ部分の抽出
67
- uncached_sentences = []
68
- cached_embeddings = {}
69
- for s in sentences:
70
- if s in embedding_cache:
71
- cached_embeddings[s] = embedding_cache[s]
72
- else:
73
- uncached_sentences.append(s)
74
-
75
- # 未キャッシュ部分をバッチで計算 (バッチサイズ=32で効率化)
76
- if uncached_sentences:
77
- uncached_embeddings = model.encode(uncached_sentences, convert_to_tensor=True, device=model.device, batch_size=32)
78
- for s, emb in zip(uncached_sentences, uncached_embeddings):
79
- embedding_cache[s] = emb.cpu().numpy() # NumPyに変換してキャッシュ (メモリ節約)
80
- cached_embeddings[s] = emb
81
-
82
- # 結果をリストに変換
83
- result = {s: emb.tolist() for s, emb in cached_embeddings.items()}
84
- return result
85
-
86
- # --- 5. Gradioインターフェースの構築 ---
87
  async def main():
88
- # モデルをロード
89
- success = await load_model()
90
- if not success:
91
- print("アプリケーションの起動を中止します。")
92
  return
93
 
94
  with gr.Blocks(theme=gr.themes.Default()) as demo:
95
- gr.Markdown(
96
- f"""
97
- # テキスト類似性計算API
98
- `{MODEL_NAME}` を使用し、文章の類似度計算および埋め込みベクトル生成を行います。
99
- 危険度���定機能は削除され、高性能化により高速・効率的な処理を実現しています。
100
- """
101
- )
102
-
103
  with gr.Tab("コンテキスト理解"):
104
- gr.Markdown("## 2つの文章を入力して類似度を計算します(コンテキスト理解機能)")
105
  with gr.Row():
106
  text_input1 = gr.Textbox(label="文章1", lines=3, placeholder="例: 今日の天気は晴れです。")
107
  text_input2 = gr.Textbox(label="文章2", lines=3, placeholder="例: 今日は良い天気ですね。")
108
  calculate_button = gr.Button("類似度を計算", variant="primary")
109
- similarity_output = gr.Number(label="コサイン類似度スコア (値が高いほど類似)")
110
- calculate_button.click(
111
- fn=calculate_similarity, inputs=[text_input1, text_input2], outputs=similarity_output
112
- )
113
-
114
  with gr.Tab("埋め込みベクトル生成"):
115
  gr.Markdown("## テキスト(改行区切り)を入力して埋め込みベクトルを生成します")
116
- with gr.Row():
117
- texts_input = gr.Textbox(label="テキスト入力 (1行に1つの文章)", lines=5, placeholder="犬が公園を走っている。\n猫が窓際で日向ぼっこをしている。")
118
  generate_button = gr.Button("ベクトルを生成", variant="primary")
119
  embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
120
  generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)
121
-
122
- # カスタムAPIエンドポイントの追加: 2つのテキストの類似度計算を非同期で公開
123
- gr.api(calculate_similarity, api_name="calculate_similarity")
124
 
125
- # --- 6. Gradioアプリの起動(高性能化: キュー拡張と同時実行制限増加) ---
126
- demo.queue(max_size=500, default_concurrency_limit=20) # キューサイズ拡大と同時実行制限増加
127
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True,
128
- max_threads=50) # スレッド数増加で並列処理強化
 
 
 
 
 
 
 
 
129
 
130
- # --- 7. アプリケーションのエントリーポイント ---
131
  if __name__ == "__main__":
132
- asyncio.run(main())
 
1
+ # enhanced_app.py
 
 
 
2
  import os
3
+ import time
4
+ import logging
5
  import asyncio
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ import threading
8
+ from typing import List, Dict, Any
9
+
10
+ import numpy as np
11
+ import torch
12
  from cachetools import TTLCache
13
+ from sentence_transformers import SentenceTransformer
14
+ import gradio as gr
15
+
16
+ # ---------- 設定 ----------
17
+ MODEL_NAME = os.environ.get("MODEL_NAME", "summerstars/MARK-Embedding")
18
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 32))
19
+ CACHE_MAXSIZE = int(os.environ.get("CACHE_MAXSIZE", 2000))
20
+ CACHE_TTL = int(os.environ.get("CACHE_TTL", 300)) # seconds
21
+ MAX_WORKERS = int(os.environ.get("MAX_WORKERS", max(4, (os.cpu_count() or 1) * 2)))
22
+ GRADIO_QUEUE_SIZE = int(os.environ.get("GRADIO_QUEUE_SIZE", 500))
23
+ GRADIO_CONCURRENCY = int(os.environ.get("GRADIO_CONCURRENCY", 20))
24
+ _LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
25
+
26
+ # ---------- ロガー ----------
27
+ logging.basicConfig(level=_LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
28
+ logger = logging.getLogger("enhanced_app")
29
+
30
+ # ---------- グローバル ----------
31
+ model: SentenceTransformer = None
32
+ executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
33
+ embedding_cache = TTLCache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
34
+ cache_lock = threading.RLock()
35
 
36
+ # Optional: Faiss を使った高速近似検索(あれば有効化できます)
37
+ HAS_FAISS = False
38
+ try:
39
+ import faiss # type: ignore
40
+ HAS_FAISS = True
41
+ logger.info("faiss detected: ANN indexing available")
42
+ except Exception:
43
+ logger.info("faiss not available: ANN indexing disabled")
44
 
45
+ # ---------- ユーティリティ ----------
46
+ def _normalize_rows(arr: np.ndarray, eps: float = 1e-12) -> np.ndarray:
47
+ norms = np.linalg.norm(arr, axis=1, keepdims=True)
48
+ norms[norms == 0] = eps
49
+ return arr / norms
50
 
51
+ def _encode_texts_sync(texts: List[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
52
+ """同期版のエンコード(スレッドプールで呼ぶ) -> 正規化済み numpy.ndarray"""
 
53
  global model
54
+ if model is None:
55
+ raise RuntimeError("Model not loaded")
56
+ # SentenceTransformer.encode with convert_to_tensor=False returns np.ndarray
57
+ with torch.no_grad():
58
+ emb = model.encode(texts, convert_to_tensor=False, batch_size=batch_size, show_progress_bar=False)
59
+ arr = np.asarray(emb, dtype=np.float32)
60
+ arr = _normalize_rows(arr)
61
+ return arr
62
+
63
+ async def encode_texts_async(texts: List[str]) -> np.ndarray:
64
+ loop = asyncio.get_running_loop()
65
+ return await loop.run_in_executor(executor, _encode_texts_sync, texts, BATCH_SIZE)
66
+
67
+ # ---------- モデルロード ----------
68
+ async def load_model() -> bool:
69
+ """非同期でモデルをロード(起動時に一度だけ呼ぶ)"""
70
+ global model
71
+ device = "cuda" if torch.cuda.is_available() else "cpu"
72
+ logger.info(f"Loading model '{MODEL_NAME}' on device: {device} ...")
73
  try:
74
+ # SentenceTransformer device 引数で内部的にデバイスを設定します
75
+ model = SentenceTransformer(MODEL_NAME, device=device, trust_remote_code=True)
76
+ model.eval()
77
+ # 注意: GPU メモリを節約したい場合は外部で 8-bit/bitsandbytes を検討
78
+ logger.info("Model loaded successfully.")
79
+ return True
 
 
80
  except Exception as e:
81
+ logger.exception("Failed to load model")
82
  return False
 
83
 
84
+ # ---------- キャッシュ付き埋め込み取得 ----------
85
+ async def get_embeddings(texts: str) -> Dict[str, List[float]]:
86
+ """
87
+ 改行区切りのテキスト群を受け取り、キャッシュを活用して埋め込みを返す。
88
+ 返却形式: { "文1": [f32,...], "文2": [...] }
89
+ """
90
+ if not texts or not texts.strip():
91
+ return {}
92
+ sentences = [s.strip() for s in texts.split("\n") if s.strip()]
93
+ if not sentences:
94
+ return {}
95
+
96
+ uncached = []
97
+ results: Dict[str, np.ndarray] = {}
98
+
99
+ # キャッシュチェック(スレッド安全)
100
+ with cache_lock:
101
+ for s in sentences:
102
+ v = embedding_cache.get(s)
103
+ if v is not None:
104
+ results[s] = v # numpy array
105
+ else:
106
+ uncached.append(s)
107
 
108
+ # 未キャッシュの文を一括で計算
109
+ if uncached:
110
+ logger.debug(f"Encoding {len(uncached)} uncached sentences (batch_size={BATCH_SIZE})")
111
+ arr = await encode_texts_async(uncached) # 正規化済み np.ndarray (N, D)
112
+ with cache_lock:
113
+ for s, vec in zip(uncached, arr):
114
+ embedding_cache[s] = vec # 保持は numpy array
115
+ results[s] = vec
116
+
117
+ # JSON シリアライズのため list に変換して返却
118
+ return {s: results[s].tolist() for s in sentences}
119
+
120
+ # ---------- 単一文の埋め込み取得(内部ユーティリティ) ----------
121
+ async def _get_single_embedding(text: str) -> np.ndarray:
122
+ if not text:
123
+ raise ValueError("empty text")
124
+ with cache_lock:
125
+ v = embedding_cache.get(text)
126
+ if v is not None:
127
+ return v
128
+ arr = await encode_texts_async([text])
129
+ vec = arr[0]
130
+ with cache_lock:
131
+ embedding_cache[text] = vec
132
+ return vec
133
+
134
+ # ---------- 類似度計算 ----------
135
  async def calculate_similarity(text1: str, text2: str) -> float:
136
+ """
137
+ 2つの文章のコサイン類似度を返す (float、-1.0〜1.0)
138
+ 非同期で動作し、内部でキャッシュを活用します。
139
+ """
140
+ start = time.time()
141
  if not text1 or not text2:
142
  return 0.0
143
+ try:
144
+ emb1 = await _get_single_embedding(text1)
145
+ emb2 = await _get_single_embedding(text2)
146
+ # どちらも正規化済み -> dot product が cosine similarity
147
+ sim = float(np.dot(emb1, emb2))
148
+ elapsed = (time.time() - start) * 1000
149
+ logger.debug(f"Similarity computed in {elapsed:.1f}ms -> {sim:.6f}")
150
+ return round(sim, 6)
151
+ except Exception as e:
152
+ logger.exception("Error in calculate_similarity")
153
+ return 0.0
154
+
155
+ # ---------- Gradio アプリ ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  async def main():
157
+ ok = await load_model()
158
+ if not ok:
159
+ logger.error("Model load failed - exiting")
 
160
  return
161
 
162
  with gr.Blocks(theme=gr.themes.Default()) as demo:
163
+ gr.Markdown(f"# 高速テキスト類似度 API — `{MODEL_NAME}`\n高速化・キャッシュ・非同期実行を適用しています。")
164
+
 
 
 
 
 
 
165
  with gr.Tab("コンテキスト理解"):
166
+ gr.Markdown("## 2つの文章を入力して類似度を計算します")
167
  with gr.Row():
168
  text_input1 = gr.Textbox(label="文章1", lines=3, placeholder="例: 今日の天気は晴れです。")
169
  text_input2 = gr.Textbox(label="文章2", lines=3, placeholder="例: 今日は良い天気ですね。")
170
  calculate_button = gr.Button("類似度を計算", variant="primary")
171
+ similarity_output = gr.Number(label="コサイン類似度スコア")
172
+ calculate_button.click(fn=calculate_similarity, inputs=[text_input1, text_input2], outputs=similarity_output)
173
+
 
 
174
  with gr.Tab("埋め込みベクトル生成"):
175
  gr.Markdown("## テキスト(改行区切り)を入力して埋め込みベクトルを生成します")
176
+ texts_input = gr.Textbox(label="テキスト入力 (1行に1つの文章)", lines=6, placeholder="犬が公園を走っている。\n猫が窓際で日向ぼっこをしている。")
 
177
  generate_button = gr.Button("ベクトルを生成", variant="primary")
178
  embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
179
  generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)
 
 
 
180
 
181
+ gr.Markdown("### ヘルスチェック")
182
+ gr.Textbox(value="OK", label="サービス状態 (静的表示)")
183
+
184
+ # API: calculate_similarity を HTTP API として公開(Gradio のバージョン依存)
185
+ try:
186
+ demo.api(calculate_similarity, api_name="calculate_similarity")
187
+ except Exception:
188
+ # 古い/新しい Gradio で API 作成方法が違うことがあるため安全にフォールバック
189
+ logger.debug("demo.api not supported in this environment; skipping automatic API registration")
190
+
191
+ demo.queue(max_size=GRADIO_QUEUE_SIZE, default_concurrency_limit=GRADIO_CONCURRENCY)
192
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True, max_threads=MAX_WORKERS)
193
 
 
194
  if __name__ == "__main__":
195
+ asyncio.run(main())