summerstars commited on
Commit
629d674
·
verified ·
1 Parent(s): cb79eff

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +67 -68
main.py CHANGED
@@ -4,95 +4,100 @@ import torch
4
  import csv
5
  import os
6
 
7
- # --- 1. モデルのロード ---
8
- print("モデルをロードしています... (初回は時間がかかることがあります)")
 
 
 
 
 
 
 
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  try:
11
- model = SentenceTransformer("summerstars/MARK-Embedding", device=device)
12
  print(f"モデルのロードが完了しました。デバイス: {device}")
13
  except Exception as e:
14
  print(f"モデルのロード中にエラーが発生しました: {e}")
15
- print("インターネット接続を確認するか、モデル名が正しいか確認してください。")
16
  exit()
17
 
18
- # --- 2. subset.csvから不適切表現サンプルを抽出 ---
19
- TOXIC_SAMPLES = []
20
- CSV_FILE_PATH = "subset.csv"
21
 
22
  # ファイルが存在するか確認
23
  if not os.path.exists(CSV_FILE_PATH):
24
  print(f"エラー: {CSV_FILE_PATH} が見つかりません。")
25
- print("app.py と同じディレクトリに subset.csv を配置してください。")
26
- exit() # ファイルがなければプログラムを終了
27
 
28
  try:
29
  with open(CSV_FILE_PATH, mode='r', encoding='utf-8') as f:
30
- # ヘッダー付きCSVとして読み込む
31
  reader = csv.DictReader(f)
32
  for row in reader:
33
- # 'Toxic' または 'Very Toxic' のスコア1以上の場合、そのテキストをサンプルして追加
34
- # .get(key, 0) はキーが存在しない場合に0を返す安全な方法
35
- if int(row.get('Toxic', 0)) > 0 or int(row.get('Very Toxic', 0)) > 0:
36
- TOXIC_SAMPLES.append(row['text'])
37
  except Exception as e:
38
  print(f"CSVファイルの読み込み中にエラーが発生しました: {e}")
39
  exit()
40
 
41
- if not TOXIC_SAMPLES:
42
- print("警告: CSVファイルから不適切表現のサンプルが見つかりませんでした。")
43
- print("CSVファイル'Toxic'または'Very Toxic'列に1以上値を持つ行があるか確認してください。")
44
- toxic_sample_embeddings = None
45
  else:
46
- print(f"データセットから {len(TOXIC_SAMPLES)} 件の不適切表現サンプルを抽出しました。")
47
- # 抽出したサンプルのベクトルを事前に計算しておく
48
- print("不適切表現サンプルのベクトルを計算しています...")
49
- toxic_sample_embeddings = model.encode(TOXIC_SAMPLES, convert_to_tensor=True, device=device)
50
  print("ベクトル計算が完了しました。")
51
 
52
 
53
- # --- 3. APIのコア機能となる関数を定義 ---
54
 
55
- def check_toxicity(text, threshold):
56
- """入力テキストの不適切度を判定する関数"""
57
  if not text.strip():
58
- return "⚪️ テキスト未入力", "テキストを入力してから判定実行ボタンを押してください。"
59
 
60
- if toxic_sample_embeddings is None:
61
  return "判定不可", "判定用のサンプルデータがロードされていません。"
62
 
63
  # 入力テキストの埋め込みを計算
64
  text_embedding = model.encode(text, convert_to_tensor=True, device=device)
65
 
66
  # 入力テキストと全てのサンプルベクトルとのコサイン類似度を計算
67
- cos_scores = util.cos_sim(text_embedding, toxic_sample_embeddings)[0]
68
 
69
  # 最も高い類似度スコ���とそのインデックスを取得
70
  top_score, top_idx = torch.max(cos_scores, dim=0)
71
- top_score = top_score.item()
72
- top_idx = top_idx.item()
73
 
74
- # 最も類似した不適切表現サンプル
75
- most_similar_sample = TOXIC_SAMPLES[top_idx]
 
76
 
77
- # しきい値と比較して判定
78
- if top_score >= threshold:
79
- judgment = "🔴 不適切の可能性あり"
80
  details = (
81
- f"最も類似したサンプル: '{most_similar_sample}'\n"
82
- f"類似度スコア: {top_score:.4f}\n"
83
- f"(判定しきい値: {threshold})"
84
  )
85
  else:
86
- judgment = "🟢 問題なし"
87
  details = (
88
- f"最も類似しサンプル文: '{most_similar_sample}'\n"
89
- f"類似度スコア: {top_score:.4f}\n"
90
- f"(判定しきい値: {threshold})"
91
  )
92
 
93
  return judgment, details
94
 
95
- # (以前類似度計算とベクトル生成の関数はそのまま)
96
  def calculate_similarity(text1, text2):
97
  if not text1 or not text2: return 0.0
98
  embeddings = model.encode([text1, text2], convert_to_tensor=True)
@@ -107,39 +112,35 @@ def get_embeddings(texts):
107
  return {s: e.tolist() for s, e in zip(sentences, embeddings)}
108
 
109
 
110
- # --- 4. Gradioインターフェースの構築 ---
111
 
112
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
113
  gr.Markdown(
114
- """
115
- # テキスト解析API (`summerstars/MARK-Embedding`)
116
- 日本語のテキスト埋め込みモデルを使用し文章の類似度計算、ベクトル化、不適切表現の判定を行います。
117
  """
118
  )
119
 
120
- with gr.Tab("不適切表現判定"):
121
- gr.Markdown("## テキストが不適切表現に類似しているかを判定します")
122
- gr.Markdown(
123
- "**仕組み:** `subset.csv`内の'Toxic'/'Very Toxic'ラベルが付いた文章と、入力テキストの意味的な近さ(コサイン類似度)を計算します。"
124
- "文脈を完全に理解するわけではないため、誤判定の可能性があります。あくまで参考としてご利用ください。"
125
- )
126
  with gr.Row():
127
- toxicity_text_input = gr.Textbox(label="判定したいテキスト", lines=5, placeholder="ここに文章を入力してください...")
128
 
129
- toxicity_threshold_slider = gr.Slider(
130
- minimum=0.3, maximum=1.0, value=0.6, step=0.05, label="不適切と判定する類似度のしきい値"
131
  )
132
 
133
- toxicity_check_button = gr.Button("判定実行", variant="primary")
134
 
135
  with gr.Row():
136
- toxicity_judgment_output = gr.Textbox(label="判定結果", interactive=False, scale=1)
137
- toxicity_details_output = gr.Textbox(label="判定詳細", interactive=False, scale=2)
138
 
139
- toxicity_check_button.click(
140
- fn=check_toxicity,
141
- inputs=[toxicity_text_input, toxicity_threshold_slider],
142
- outputs=[toxicity_judgment_output, toxicity_details_output]
143
  )
144
 
145
  with gr.Tab("文章の類似度計算"):
@@ -151,9 +152,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
151
  calculate_button = gr.Button("類似度を計算", variant="primary")
152
  similarity_output = gr.Number(label="コサイン類似度スコア (値が高いほど類似)")
153
  calculate_button.click(
154
- fn=calculate_similarity,
155
- inputs=[text_input1, text_input2],
156
- outputs=similarity_output
157
  )
158
 
159
  with gr.Tab("埋め込みベクトル生成"):
@@ -165,6 +164,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
165
  embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
166
  generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)
167
 
168
- # --- 5. Gradioアプリの起動 ---
169
  if __name__ == "__main__":
170
  demo.launch()
 
4
  import csv
5
  import os
6
 
7
+ # --- 1. 設定項目 ---
8
+ # 読み込むCSVファイル名
9
+ CSV_FILE_PATH = "text_data.csv"
10
+ # 埋め込みモデル
11
+ MODEL_NAME = "summerstars/MARK-Embedding"
12
+
13
+
14
+ # --- 2. モデルのロード ---
15
+ print(f"モデル '{MODEL_NAME}' をロードしています...")
16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  try:
18
+ model = SentenceTransformer(MODEL_NAME, device=device)
19
  print(f"モデルのロードが完了しました。デバイス: {device}")
20
  except Exception as e:
21
  print(f"モデルのロード中にエラーが発生しました: {e}")
 
22
  exit()
23
 
24
+ # --- 3. CSVデータからサンプルを読み込み、ベクトル化 ---
25
+ SAMPLE_TEXTS = []
26
+ SAMPLE_CATEGORIES = []
27
 
28
  # ファイルが存在するか確認
29
  if not os.path.exists(CSV_FILE_PATH):
30
  print(f"エラー: {CSV_FILE_PATH} が見つかりません。")
31
+ print("app.py と同じディレクトリにCSVファイルを配置してください。")
32
+ exit()
33
 
34
  try:
35
  with open(CSV_FILE_PATH, mode='r', encoding='utf-8') as f:
 
36
  reader = csv.DictReader(f)
37
  for row in reader:
38
+ # '発言''カテゴリ'存在するこを確認
39
+ if '発言' in row and 'カテゴリ' in row and row['発言'].strip():
40
+ SAMPLE_TEXTS.append(row['発言'])
41
+ SAMPLE_CATEGORIES.append(row['カテゴリ'])
42
  except Exception as e:
43
  print(f"CSVファイルの読み込み中にエラーが発生しました: {e}")
44
  exit()
45
 
46
+ if not SAMPLE_TEXTS:
47
+ print(f"警告: {CSV_FILE_PATH} から判定用のサンプルが読み込めませんでした。")
48
+ print("CSVファイル'発言' 'カテゴリ' があり、データが含まれているか確認してください。")
49
+ sample_embeddings = None
50
  else:
51
+ print(f"データセットから {len(SAMPLE_TEXTS)} 件のサンプルを読み込みました。")
52
+ print("サンプルのベクトルを計算しています...")
53
+ sample_embeddings = model.encode(SAMPLE_TEXTS, convert_to_tensor=True, device=device)
 
54
  print("ベクトル計算が完了しました。")
55
 
56
 
57
+ # --- 4. APIのコア機能となる関数を定義 ---
58
 
59
+ def check_category_similarity(text, threshold):
60
+ """入力テキストがどカテゴリの発言に最も類似しているかを判定する関数"""
61
  if not text.strip():
62
+ return "⚪️ テキスト未入力", "テキストを入力してください。"
63
 
64
+ if sample_embeddings is None:
65
  return "判定不可", "判定用のサンプルデータがロードされていません。"
66
 
67
  # 入力テキストの埋め込みを計算
68
  text_embedding = model.encode(text, convert_to_tensor=True, device=device)
69
 
70
  # 入力テキストと全てのサンプルベクトルとのコサイン類似度を計算
71
+ cos_scores = util.cos_sim(text_embedding, sample_embeddings)[0]
72
 
73
  # 最も高い類似度スコ���とそのインデックスを取得
74
  top_score, top_idx = torch.max(cos_scores, dim=0)
75
+ top_score_item = top_score.item()
76
+ top_idx_item = top_idx.item()
77
 
78
+ # 最も類似したサンプルとそのカテゴリ
79
+ most_similar_text = SAMPLE_TEXTS[top_idx_item]
80
+ most_similar_category = SAMPLE_CATEGORIES[top_idx_item]
81
 
82
+ # しきい値に基づいて判定メッセージを作成
83
+ if top_score_item >= threshold:
84
+ judgment = f"🔴 {most_similar_category} の可能性"
85
  details = (
86
+ f"入力は「{most_similar_category}」カテゴリの発言に類似しています。\n\n"
87
+ f"最も類似したサンプル文: '{most_similar_text}'\n"
88
+ f"類似度スコア: {top_score_item:.4f} (しきい値: {threshold})"
89
  )
90
  else:
91
+ judgment = "🟢 特定のカテゴリとの強い類似性なし"
92
  details = (
93
+ f"最も近かっのは「{most_similar_category}」カテゴリの発言です。\n\n"
94
+ f"最も類似したサンプル文: '{most_similar_text}'\n"
95
+ f"類似度スコア: {top_score_item:.4f} (しきい値: {threshold})"
96
  )
97
 
98
  return judgment, details
99
 
100
+ # (既存機能の関数は変更なし)
101
  def calculate_similarity(text1, text2):
102
  if not text1 or not text2: return 0.0
103
  embeddings = model.encode([text1, text2], convert_to_tensor=True)
 
112
  return {s: e.tolist() for s, e in zip(sentences, embeddings)}
113
 
114
 
115
+ # --- 5. Gradioインターフェースの構築 ---
116
 
117
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
118
  gr.Markdown(
119
+ f"""
120
+ # テキストカテゴリ類似性判定API
121
+ `{MODEL_NAME}` を使用し、入力テキストが `{CSV_FILE_PATH}` 内どのカテゴリの発言に類似しているかを判定ます。
122
  """
123
  )
124
 
125
+ with gr.Tab("カテゴリ類似性判定"):
126
+ gr.Markdown("## テキストがどのカテゴリの発言に類似しているかを判定します")
 
 
 
 
127
  with gr.Row():
128
+ text_input = gr.Textbox(label="判定したいテキスト", lines=5, placeholder="ここに文章を入力してください...")
129
 
130
+ threshold_slider = gr.Slider(
131
+ minimum=0.3, maximum=1.0, value=0.6, step=0.05, label="判定のしきい値 (この値以上の類似度で警告)"
132
  )
133
 
134
+ check_button = gr.Button("判定実行", variant="primary")
135
 
136
  with gr.Row():
137
+ judgment_output = gr.Textbox(label="判定結果", interactive=False, scale=1)
138
+ details_output = gr.Textbox(label="判定詳細", interactive=False, scale=2, lines=4)
139
 
140
+ check_button.click(
141
+ fn=check_category_similarity,
142
+ inputs=[text_input, threshold_slider],
143
+ outputs=[judgment_output, details_output]
144
  )
145
 
146
  with gr.Tab("文章の類似度計算"):
 
152
  calculate_button = gr.Button("類似度を計算", variant="primary")
153
  similarity_output = gr.Number(label="コサイン類似度スコア (値が高いほど類似)")
154
  calculate_button.click(
155
+ fn=calculate_similarity, inputs=[text_input1, text_input2], outputs=similarity_output
 
 
156
  )
157
 
158
  with gr.Tab("埋め込みベクトル生成"):
 
164
  embeddings_output = gr.JSON(label="生成された埋め込みベクトル")
165
  generate_button.click(fn=get_embeddings, inputs=texts_input, outputs=embeddings_output)
166
 
167
+ # --- 6. Gradioアプリの起動 ---
168
  if __name__ == "__main__":
169
  demo.launch()