summerstars commited on
Commit
f930370
·
verified ·
1 Parent(s): 232711a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +125 -65
main.py CHANGED
@@ -1,84 +1,155 @@
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util
3
  import torch
4
- import numpy as np
 
5
 
6
  # --- 1. モデルのロード ---
7
- # スクリプト起動時に一度だけモデルをロードする
8
  print("モデルをロードしています... (初回は時間がかかることがあります)")
9
- # GPUが利用可能ならGPUを、そうでなければCPUを使用
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
- model = SentenceTransformer("summerstars/MARK-Embedding", device=device)
12
- print(f"モデルのロードが完了しました。デバイス: {device}")
13
-
14
- # --- 2. APIのコア機能となる関数を定義 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
16
  def calculate_similarity(text1, text2):
17
- """
18
- 2つのテキストのコサイン類似度を計算する関数
19
- """
20
- if not text1 or not text2:
21
- return 0.0
22
-
23
- # テキストをリストにまとめる
24
- sentences = [text1, text2]
25
-
26
- # モデルを使って埋め込みベクトルを計算
27
- # convert_to_tensor=TrueでPyTorchテン��ルとして結果を得る
28
- embeddings = model.encode(sentences, convert_to_tensor=True)
29
-
30
- # コサイン類似度を計算
31
  cosine_scores = util.cos_sim(embeddings[0], embeddings[1])
32
-
33
- # テンソルからfloat値を取り出して返す
34
  return round(cosine_scores.item(), 4)
35
 
36
-
37
  def get_embeddings(texts):
38
- """
39
- 改行区切りのテキストから埋め込みベクトルを生成する関数
40
- """
41
- if not texts.strip():
42
- return {}
43
-
44
- # 改行でテキストを分割し、空行は無視する
45
  sentences = [s.strip() for s in texts.strip().split('\n') if s.strip()]
46
-
47
- if not sentences:
48
- return {}
49
-
50
- # 埋め込みベクトルを計算
51
- # convert_to_numpy=TrueでNumpy配列として結果を得る
52
  embeddings = model.encode(sentences, convert_to_numpy=True)
 
53
 
54
- # テキストとベクトルのペアを辞書形式で作成
55
- # JSONで扱いやすいようにNumpy配列をリストに変換
56
- response_data = {
57
- sentence: embedding.tolist()
58
- for sentence, embedding in zip(sentences, embeddings)
59
- }
60
-
61
- return response_data
62
 
63
- # --- 3. Gradioインターフェースの構築 ---
64
 
65
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
66
  gr.Markdown(
67
  """
68
- # 埋め込みモデル `summerstars/MARK-Embedding` API
69
- 日本語のテキスト埋め込みモデルを使用して、文章の類似度計算やベクトル化を行うことができます。
70
  """
71
  )
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  with gr.Tab("文章の類似度計算"):
 
74
  gr.Markdown("## 2つの文章を入力して類似度を計算します")
75
  with gr.Row():
76
  text_input1 = gr.Textbox(label="文章1", lines=3, placeholder="例: 今日の天気は晴れです。")
77
  text_input2 = gr.Textbox(label="文章2", lines=3, placeholder="例: 今日は良い天気ですね。")
78
-
79
  calculate_button = gr.Button("類似度を計算", variant="primary")
80
  similarity_output = gr.Number(label="コサイン類似度スコア (値が高いほど類似)")
81
-
82
  calculate_button.click(
83
  fn=calculate_similarity,
84
  inputs=[text_input1, text_input2],
@@ -86,25 +157,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
86
  )
87
 
88
  with gr.Tab("埋め込みベクトル生成"):
 
89
  gr.Markdown("## テキスト(改行区切り)を入力して埋め込みベクトルを生成します")
90
  with gr.Row():
91
- texts_input = gr.Textbox(
92
- label="テキスト入力 (1行に1つの文章)",
93
- lines=5,
94
- placeholder="""犬が公園を走っている。
95
- 猫が窓際で日向ぼっこをしている。
96
- 明日の会議の資料を準備する。"""
97
- )
98
-
99
  generate_button = gr.Button("ベクトルを生成", variant="primary")
100
  embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
 
101
 
102
- generate_button.click(
103
- fn=get_embeddings,
104
- inputs=texts_input,
105
- outputs=embeddings_output
106
- )
107
-
108
- # --- 4. Gradioアプリの起動 ---
109
  if __name__ == "__main__":
110
  demo.launch()
 
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util
3
  import torch
4
+ import csv
5
+ import os
6
 
7
  # --- 1. モデルのロード ---
 
8
  print("モデルをロードしています... (初回は時間がかかることがあります)")
 
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ try:
11
+ model = SentenceTransformer("summerstars/MARK-Embedding", device=device)
12
+ print(f"モデルのロードが完了しました。デバイス: {device}")
13
+ except Exception as e:
14
+ print(f"モデルのロード中にエラーが発生しました: {e}")
15
+ print("インターネット接続を確認するか、モデル名が正しいか確認してください。")
16
+ exit()
17
+
18
+ # --- 2. subset.csvから不適切表現サンプルを抽出 ---
19
+ TOXIC_SAMPLES = []
20
+ CSV_FILE_PATH = "subset.csv"
21
+
22
+ # ファイルが存在するか確認
23
+ if not os.path.exists(CSV_FILE_PATH):
24
+ print(f"エラー: {CSV_FILE_PATH} が見つかりません。")
25
+ print("app.py と同じディレクトリに subset.csv を配置してください。")
26
+ exit() # ファイルがなければプログラムを終了
27
+
28
+ try:
29
+ with open(CSV_FILE_PATH, mode='r', encoding='utf-8') as f:
30
+ # ヘッダー付きCSVとして読み込む
31
+ reader = csv.DictReader(f)
32
+ for row in reader:
33
+ # 'Toxic' または 'Very Toxic' のスコアが1以上の場合、そのテキストをサンプルとして追加
34
+ # .get(key, 0) はキーが存在しない場合に0を返す安全な方法
35
+ if int(row.get('Toxic', 0)) > 0 or int(row.get('Very Toxic', 0)) > 0:
36
+ TOXIC_SAMPLES.append(row['text'])
37
+ except Exception as e:
38
+ print(f"CSVファイルの読み込み中にエラーが発生しました: {e}")
39
+ exit()
40
+
41
+ if not TOXIC_SAMPLES:
42
+ print("警告: CSVファイルから不適切表現のサンプルが見つかりませんでした。")
43
+ print("CSVファイルの'Toxic'または'Very Toxic'列に1以上の値を持つ行があるか確認してください。")
44
+ toxic_sample_embeddings = None
45
+ else:
46
+ print(f"データセットから {len(TOXIC_SAMPLES)} 件の不適切表現サンプルを抽出しました。")
47
+ # 抽出したサンプルのベクトルを事前に計算しておく
48
+ print("不適切表現サンプルのベクトルを計算しています...")
49
+ toxic_sample_embeddings = model.encode(TOXIC_SAMPLES, convert_to_tensor=True, device=device)
50
+ print("ベクトル計算が完了しました。")
51
+
52
+
53
+ # --- 3. APIのコア機能となる関数を定義 ---
54
+
55
+ def check_toxicity(text, threshold):
56
+ """入力テキストの不適切度を判定する関数"""
57
+ if not text.strip():
58
+ return "⚪️ テキスト未入力", "テキストを入力してから判定実行ボタンを押してください。"
59
+
60
+ if toxic_sample_embeddings is None:
61
+ return "判定不可", "判定用のサンプルデータがロードされていません。"
62
+
63
+ # 入力テキストの埋め込みを計算
64
+ text_embedding = model.encode(text, convert_to_tensor=True, device=device)
65
+
66
+ # 入力テキストと全てのサンプルベクトルとのコサイン類似度を計算
67
+ cos_scores = util.cos_sim(text_embedding, toxic_sample_embeddings)[0]
68
+
69
+ # 最も高い類似度スコアとそのインデックスを取得
70
+ top_score, top_idx = torch.max(cos_scores, dim=0)
71
+ top_score = top_score.item()
72
+ top_idx = top_idx.item()
73
+
74
+ # 最も類似した不適切表現サンプル
75
+ most_similar_sample = TOXIC_SAMPLES[top_idx]
76
+
77
+ # しきい値と比較して判定
78
+ if top_score >= threshold:
79
+ judgment = "🔴 不適切の可能性あり"
80
+ details = (
81
+ f"最も類似したサンプル文: '{most_similar_sample}'\n"
82
+ f"類似度スコア: {top_score:.4f}\n"
83
+ f"(判定しきい値: {threshold})"
84
+ )
85
+ else:
86
+ judgment = "🟢 問題なし"
87
+ details = (
88
+ f"最も類似したサンプル文: '{most_similar_sample}'\n"
89
+ f"類似度スコア: {top_score:.4f}\n"
90
+ f"(判定しきい値: {threshold})"
91
+ )
92
+
93
+ return judgment, details
94
 
95
+ # (以前の類似度計算とベクトル生成の関数はそのまま)
96
  def calculate_similarity(text1, text2):
97
+ if not text1 or not text2: return 0.0
98
+ embeddings = model.encode([text1, text2], convert_to_tensor=True)
 
 
 
 
 
 
 
 
 
 
 
 
99
  cosine_scores = util.cos_sim(embeddings[0], embeddings[1])
 
 
100
  return round(cosine_scores.item(), 4)
101
 
 
102
  def get_embeddings(texts):
103
+ if not texts.strip(): return {}
 
 
 
 
 
 
104
  sentences = [s.strip() for s in texts.strip().split('\n') if s.strip()]
105
+ if not sentences: return {}
 
 
 
 
 
106
  embeddings = model.encode(sentences, convert_to_numpy=True)
107
+ return {s: e.tolist() for s, e in zip(sentences, embeddings)}
108
 
 
 
 
 
 
 
 
 
109
 
110
+ # --- 4. Gradioインターフェースの構築 ---
111
 
112
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
113
  gr.Markdown(
114
  """
115
+ # テキスト解析API (`summerstars/MARK-Embedding`)
116
+ 日本語のテキスト埋め込みモデルを使用して、文章の類似度計算、ベクトル化、不適切表現の判定を行います。
117
  """
118
  )
119
 
120
+ with gr.Tab("不適切表現判定"):
121
+ gr.Markdown("## テキストが不適切表現に類似しているかを判定します")
122
+ gr.Markdown(
123
+ "**仕組み:** `subset.csv`内の'Toxic'/'Very Toxic'ラベルが付いた文章と、入力テキストの意味的な近さ(コサイン類似度)を計算します。"
124
+ "文脈を完全に理解するわけではないため、誤判定の可能性があります。あくまで参考としてご利用ください。"
125
+ )
126
+ with gr.Row():
127
+ toxicity_text_input = gr.Textbox(label="判定したいテキスト", lines=5, placeholder="ここに文章を入力してください...")
128
+
129
+ toxicity_threshold_slider = gr.Slider(
130
+ minimum=0.3, maximum=1.0, value=0.6, step=0.05, label="不適切と判定する類似度のしきい値"
131
+ )
132
+
133
+ toxicity_check_button = gr.Button("判定実行", variant="primary")
134
+
135
+ with gr.Row():
136
+ toxicity_judgment_output = gr.Textbox(label="判定結果", interactive=False, scale=1)
137
+ toxicity_details_output = gr.Textbox(label="判定詳細", interactive=False, scale=2)
138
+
139
+ toxicity_check_button.click(
140
+ fn=check_toxicity,
141
+ inputs=[toxicity_text_input, toxicity_threshold_slider],
142
+ outputs=[toxicity_judgment_output, toxicity_details_output]
143
+ )
144
+
145
  with gr.Tab("文章の類似度計算"):
146
+ # (変更なし)
147
  gr.Markdown("## 2つの文章を入力して類似度を計算します")
148
  with gr.Row():
149
  text_input1 = gr.Textbox(label="文章1", lines=3, placeholder="例: 今日の天気は晴れです。")
150
  text_input2 = gr.Textbox(label="文章2", lines=3, placeholder="例: 今日は良い天気ですね。")
 
151
  calculate_button = gr.Button("類似度を計算", variant="primary")
152
  similarity_output = gr.Number(label="コサイン類似度スコア (値が高いほど類似)")
 
153
  calculate_button.click(
154
  fn=calculate_similarity,
155
  inputs=[text_input1, text_input2],
 
157
  )
158
 
159
  with gr.Tab("埋め込みベクトル生成"):
160
+ # (変更なし)
161
  gr.Markdown("## テキスト(改行区切り)を入力して埋め込みベクトルを生成します")
162
  with gr.Row():
163
+ texts_input = gr.Textbox(label="テキスト入力 (1行に1つの文章)", lines=5, placeholder="犬が公園を走っている。\n猫が窓際で日向ぼっこをしている。")
 
 
 
 
 
 
 
164
  generate_button = gr.Button("ベクトルを生成", variant="primary")
165
  embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
166
+ generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)
167
 
168
+ # --- 5. Gradioアプリの起動 ---
 
 
 
 
 
 
169
  if __name__ == "__main__":
170
  demo.launch()