Qifan Zhang committed on
Commit
b91a6bd
·
1 Parent(s): 051bcc8

Add Qwen3 embedding support with last-token pooling

Browse files
.codex/environments/environment.toml DELETED
@@ -1,11 +0,0 @@
1
- # THIS IS AUTOGENERATED. DO NOT EDIT MANUALLY
2
- version = 1
3
- name = "TransDis-CreativityAutoAssessment"
4
-
5
- [setup]
6
- script = ""
7
-
8
- [[actions]]
9
- name = "运行"
10
- icon = "run"
11
- command = "GRADIO_SERVER_PORT=7860 /Users/eric/.local/bin/uv run --python 3.11 --with gradio==6.14.0 --with-requirements requirements.txt python app.py"
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -1,6 +1,34 @@
1
- .idea
2
- flagged
3
- data/example
4
- data/tmp
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  output.csv
 
 
 
 
 
 
 
1
+ # Editor / OS noise
2
+ .DS_Store
3
+ .idea/
4
+ .vscode/
5
 
6
+ # Python caches
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+ .pytest_cache/
11
+ .ruff_cache/
12
+ .mypy_cache/
13
+ .coverage
14
+ htmlcov/
15
+
16
+ # Local environments
17
+ .env
18
+ .env.*
19
+ !.env.example
20
+ .venv/
21
+ venv/
22
+ env/
23
+
24
+ # Local app/runtime output
25
+ flagged/
26
+ data/example/
27
+ data/tmp/
28
  output.csv
29
+ outputs/
30
+ tmp/
31
+ temp/
32
+
33
+ # Local Codex app metadata
34
+ .codex/
AGENTS.md CHANGED
@@ -10,7 +10,7 @@
10
 
11
  - `app.py`:Gradio UI、输入解析、行数限制、任务分发、输出 CSV 生成。
12
  - `utils/pipeline.py`:`Originality` 和 `Flexibility` 的评分流程。
13
- - `utils/models.py`:模型列表、tokenizer/model 加载、pooling 行为。
14
  - `data/description.txt`:Space 页面中展示的长篇双语说明。
15
  - `.pre-commit-config.yaml`:本地 Ruff pre-commit hook。
16
 
@@ -34,7 +34,8 @@
34
  ## 模型与评分规则
35
 
36
  - 除非用户明确要求模型或评分变化,不要修改模型下拉列表、默认模型、pooling 选项或评分公式。
37
- - `ModelWithPooling` 使用 Hugging Face `AutoTokenizer` 和 `AutoModel`;UI 暴露的 pooling 模式是 `mean` 和 `cls`。
 
38
  - 避免在 app 启动时提前加载模型。模型下载和加载应只在处理请求时发生。
39
  - 做轻量本地测试时,优先 monkeypatch `pipeline.p0_originality` / `pipeline.p1_flexibility`,不要为了普通 smoke test 下载大模型;只有真实推理验证才下载模型。
40
 
 
10
 
11
  - `app.py`:Gradio UI、输入解析、行数限制、任务分发、输出 CSV 生成。
12
  - `utils/pipeline.py`:`Originality` 和 `Flexibility` 的评分流程。
13
+ - `utils/models.py`:模型列表、embedding adapter registry、tokenizer/model 加载、pooling 行为。
14
  - `data/description.txt`:Space 页面中展示的长篇双语说明。
15
  - `.pre-commit-config.yaml`:本地 Ruff pre-commit hook。
16
 
 
34
  ## 模型与评分规则
35
 
36
  - 除非用户明确要求模型或评分变化,不要修改模型下拉列表、默认模型、pooling 选项或评分公式。
37
+ - `utils.models.get_embedding_model()` 是评分流程的统一模型入口。legacy 模型走 `ModelWithPooling`,使用 Hugging Face `AutoTokenizer` 和 `AutoModel`;`Qwen/Qwen3-Embedding-0.6B` 走专用 adapter,自动使用官方推荐的 last-token pooling 和向量归一化。
38
+ - UI 使用 Gradio `Blocks` 实现,但视觉结构应保持原来的左右两列:左侧输入、右侧输出。legacy 模型的 pooling 选项是 `mean` 和 `cls`;选择 `Qwen/Qwen3-Embedding-0.6B` 时,UI 应切换并锁定为 `last-token`,后端也应通过 `effective_pooling()` 兜底强制。
39
  - 避免在 app 启动时提前加载模型。模型下载和加载应只在处理请求时发生。
40
  - 做轻量本地测试时,优先 monkeypatch `pipeline.p0_originality` / `pipeline.p1_flexibility`,不要为了普通 smoke test 下载大模型;只有真实推理验证才下载模型。
41
 
README.md CHANGED
@@ -45,13 +45,13 @@ id,prompt,response
45
 
46
  ## 模型选择
47
 
48
- 应用会在下拉菜单中提供多语言、英文和中文 Transformer 检查点。默认模型是:
49
 
50
  ```text
51
  sentence-transformers/paraphrase-multilingual-mpnet-base-v2
52
  ```
53
 
54
- 可选 pooling 方式包括 `mean` 和 `cls`。如果使用 `bert-base-chinese`,建议选择 `mean` pooling。
55
 
56
  ## 本地开发
57
 
 
45
 
46
  ## 模型选择
47
 
48
+ 应用会在下拉菜单中提供多语言、英文和中文 Transformer 检查点,包括多语言 embedding 模型 `Qwen/Qwen3-Embedding-0.6B`。默认模型是:
49
 
50
  ```text
51
  sentence-transformers/paraphrase-multilingual-mpnet-base-v2
52
  ```
53
 
54
+ 可选 pooling 方式主要用于 legacy Transformer 检查点,包括 `mean` 和 `cls`。如果使用 `bert-base-chinese`,建议选择 `mean` pooling。选择 `Qwen/Qwen3-Embedding-0.6B` 时,界面会自动切换并锁定为 `last-token`,应用会使用该模型推荐的 last-token pooling 和向量归一化,即使 API 请求传入其他 pooling 值也会被覆盖。
55
 
56
  ## 本地开发
57
 
app.py CHANGED
@@ -2,14 +2,19 @@ import os
2
  import tempfile
3
  import traceback
4
  from io import StringIO
5
- from typing import Optional
6
 
7
  import gradio as gr
8
  import pandas as pd
9
  from loguru import logger
10
 
11
  from utils import pipeline
12
- from utils.models import list_models
 
 
 
 
 
13
 
14
 
15
  def resolve_file_path(file) -> str:
@@ -33,14 +38,35 @@ def read_data(filepath: str) -> Optional[pd.DataFrame]:
33
  return df
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def process(
37
  task_name: str,
38
  model_name: str,
39
  pooling: str,
40
  text: str,
41
  file=None,
42
- ) -> (None, pd.DataFrame, str):
43
  try:
 
44
  logger.info(f'Processing {task_name} with {model_name} and {pooling}')
45
  # load file
46
  if file:
@@ -56,6 +82,11 @@ def process(
56
  if len(df) > 10000:
57
  raise Exception('Data exceeds 10,000 rows')
58
 
 
 
 
 
 
59
  # process
60
  if task_name == 'Originality':
61
  df = pipeline.p0_originality(df, model_name, pooling)
@@ -68,7 +99,7 @@ def process(
68
  fd, path = tempfile.mkstemp(prefix='transdis_', suffix='.csv')
69
  os.close(fd)
70
  df.to_csv(path, index=False, encoding='utf-8-sig')
71
- return None, df.iloc[:10], path
72
 
73
  except Exception:
74
  error = traceback.format_exc()
@@ -80,7 +111,7 @@ def process(
80
  'text': text,
81
  'file': file,
82
  })
83
- return f'Something wrong\n\n{error}', None, None
84
 
85
 
86
  # input
@@ -96,8 +127,8 @@ model_name_dropdown = gr.components.Dropdown(
96
  )
97
  pooling_dropdown = gr.components.Dropdown(
98
  label='Pooling',
99
- value='mean',
100
- choices=['mean', 'cls']
101
  )
102
  text_input = gr.components.Textbox(
103
  value=open('data/example_xlm.csv', 'r').read(),
@@ -110,13 +141,34 @@ text_output = gr.components.Textbox(label='Output')
110
  dataframe_output = gr.components.Dataframe(label='DataFrame')
111
  file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])
112
 
113
- app = gr.Interface(
114
- fn=process,
115
- inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
116
- outputs=[text_output, dataframe_output, file_output],
117
- description=open('data/description.txt', 'r').read(),
118
- title='TransDis-CreativityAutoAssessment',
119
- concurrency_limit=1,
120
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if __name__ == '__main__':
122
  app.launch(max_threads=1)
 
2
  import tempfile
3
  import traceback
4
  from io import StringIO
5
+ from typing import Generator, Optional
6
 
7
  import gradio as gr
8
  import pandas as pd
9
  from loguru import logger
10
 
11
  from utils import pipeline
12
+ from utils.models import QWEN3_EMBEDDING_MODEL, get_embedding_model, list_models
13
+
14
+ LEGACY_POOLING_CHOICES = ['mean', 'cls']
15
+ QWEN3_POOLING_CHOICES = ['last-token']
16
+ DEFAULT_POOLING = 'mean'
17
+ QWEN3_POOLING = 'last-token'
18
 
19
 
20
  def resolve_file_path(file) -> str:
 
38
  return df
39
 
40
 
41
+ def effective_pooling(model_name: str, pooling: str) -> str:
42
+ if model_name == QWEN3_EMBEDDING_MODEL:
43
+ return QWEN3_POOLING
44
+ return pooling
45
+
46
+
47
+ def update_pooling_for_model(model_name: str):
48
+ if model_name == QWEN3_EMBEDDING_MODEL:
49
+ return gr.update(
50
+ choices=QWEN3_POOLING_CHOICES,
51
+ value=QWEN3_POOLING,
52
+ interactive=False,
53
+ )
54
+ return gr.update(
55
+ choices=LEGACY_POOLING_CHOICES,
56
+ value=DEFAULT_POOLING,
57
+ interactive=True,
58
+ )
59
+
60
+
61
  def process(
62
  task_name: str,
63
  model_name: str,
64
  pooling: str,
65
  text: str,
66
  file=None,
67
+ ) -> Generator[tuple[str, Optional[pd.DataFrame], Optional[str]], None, None]:
68
  try:
69
+ pooling = effective_pooling(model_name, pooling)
70
  logger.info(f'Processing {task_name} with {model_name} and {pooling}')
71
  # load file
72
  if file:
 
82
  if len(df) > 10000:
83
  raise Exception('Data exceeds 10,000 rows')
84
 
85
+ yield f'模型加载中:{model_name}', None, None
86
+ get_embedding_model(model_name)
87
+
88
+ yield '计算中...', None, None
89
+
90
  # process
91
  if task_name == 'Originality':
92
  df = pipeline.p0_originality(df, model_name, pooling)
 
99
  fd, path = tempfile.mkstemp(prefix='transdis_', suffix='.csv')
100
  os.close(fd)
101
  df.to_csv(path, index=False, encoding='utf-8-sig')
102
+ yield '完成', df.iloc[:10], path
103
 
104
  except Exception:
105
  error = traceback.format_exc()
 
111
  'text': text,
112
  'file': file,
113
  })
114
+ yield f'Something wrong\n\n{error}', None, None
115
 
116
 
117
  # input
 
127
  )
128
  pooling_dropdown = gr.components.Dropdown(
129
  label='Pooling',
130
+ value=DEFAULT_POOLING,
131
+ choices=LEGACY_POOLING_CHOICES
132
  )
133
  text_input = gr.components.Textbox(
134
  value=open('data/example_xlm.csv', 'r').read(),
 
141
  dataframe_output = gr.components.Dataframe(label='DataFrame')
142
  file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])
143
 
144
+ with gr.Blocks(title='TransDis-CreativityAutoAssessment') as app:
145
+ gr.Markdown('# TransDis-CreativityAutoAssessment')
146
+ gr.Markdown(open('data/description.txt', 'r').read())
147
+ with gr.Row():
148
+ with gr.Column():
149
+ task_name_dropdown.render()
150
+ model_name_dropdown.render()
151
+ pooling_dropdown.render()
152
+ text_input.render()
153
+ file_input.render()
154
+ submit_button = gr.Button('Submit', variant='primary')
155
+ with gr.Column():
156
+ text_output.render()
157
+ dataframe_output.render()
158
+ file_output.render()
159
+
160
+ model_name_dropdown.change(
161
+ fn=update_pooling_for_model,
162
+ inputs=model_name_dropdown,
163
+ outputs=pooling_dropdown,
164
+ )
165
+ submit_button.click(
166
+ fn=process,
167
+ inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
168
+ outputs=[text_output, dataframe_output, file_output],
169
+ api_name='predict',
170
+ concurrency_limit=1,
171
+ )
172
+
173
  if __name__ == '__main__':
174
  app.launch(max_threads=1)
data/description.txt CHANGED
@@ -1,5 +1,5 @@
1
- TransDis系统,是一个基于Transformer语言模型的语义距离评分系统,用于自动评估中文(或其他语言)的多用途任务(AUT)中的独创性和灵活性(论文见,https://link.springer.com/article/10.3758/s13428-023-02313-z )。 输入被试(id)+提示词+回答的数据,每行1个用途,用逗号隔开。您可以通过文本框直接输入数据,也可以上传用逗号隔开的CSV格式文件或xlsx文件作为输入,CSV输入优先级高于文本框输入。 您可以选择用于评分的模型,请注意sentence-transformers_paraphrase-multilingual-mpnet-base-v2sentence-transformers_paraphrase-multilingual-MiniLM-L12-v2可用于多语言,其他模型仅适用于英文或中文。 我们提供Pooling方法的选择,对于bert-base-chinese建议使用mean pooling。 如发生错误,请试着简化你的数据——用更少的行试试。如果不行,则可能是输入格式错误,请尝试重新保存为逗号分隔的CSV,然后再上传CSV文件。 如运行较慢,可以复制此空间至您的帐号(我们建议这种方式),并选择升级版的硬件以提升处理速度。 如需更多帮助或报告其他bug,请联系ydd409@163.com。
2
 
3
- TranDis, a semantic distance scoring system based on transformer-based language models, can be a useful tool to automatically assess originality and flexibility for AUT in Chinese or other languages (see the paper at https://link.springer.com/article/10.3758/s13428-023-02313-z). Enter your participant ID + prompt + response data, one per line, with a COMMA between each variable. You can either input data directly into the text box or upload a comma-separated CSV file or a XLSX file as input. Please note that if both methods are used, the CSV input will take precedence over the text box input. You can choose the model to use for scoring. Please note that sentence-transformers_paraphrase-multilingual-mpnet-base-v2 and sentence-transformers_paraphrase-multilingual-MiniLM-L12-v2 are applicable to multiple languages; cyclone_simcse-chinese-roberta-wwm-ext is only applicable to Chinese; sentence-transformers/all-mpnet-base-v2 and sentence-transformers/all-MiniLM-L12-v2 are only applicable to English. If an error occurs, try simplifying your data - does it work with fewer rows? If not, the input format may be incorrect. If the process is sluggish, you have the option to duplicate this space to your account (we recommend this approach) and choose an enhanced hardware configuration for improved processing speed. For more assistance or to report potential issues with our system, please contact ydd409@163.com.
4
 
5
- Reference: Yang, T., Zhang, Q., Sun, Z., & Hou, Y. (2023). Automatic Assessment of Divergent Thinking in Chinese Language with TransDis: A Transformer-Based Language Model Approach. Behavior Research Methods. https://doi.org/10.3758/s13428-023-02313-z
 
1
+ TransDis系统,是一个基于Transformer语言模型的语义距离评分系统,用于自动评估中文(或其他语言)的多用途任务(AUT)中的独创性和灵活性(论文见,https://link.springer.com/article/10.3758/s13428-023-02313-z )。 输入被试(id)+提示词+回答的数据,每行1个用途,用逗号隔开。您可以通过文本框直接输入数据,也可以上传用逗号隔开的CSV格式文件或xlsx文件作为输入,CSV输入优先级高于文本框输入。 您可以选择用于评分的模型,请注意sentence-transformers_paraphrase-multilingual-mpnet-base-v2、sentence-transformers_paraphrase-multilingual-MiniLM-L12-v2和Qwen/Qwen3-Embedding-0.6B可用于多语言,其他模型仅适用于英文或中文。 我们提供Pooling方法的选择,对于bert-base-chinese建议使用mean pooling;选择Qwen/Qwen3-Embedding-0.6B时,界面会自动切换并锁定为last-token,系统会使用该模型推荐的pooling和归一化方式。 如发生错误,请试着简化你的数据——用更少的行试试。如果不行,则可能是输入格式错误,请尝试重新保存为逗号分隔的CSV,然后再上传CSV文件。 如运行较慢,可以复制此空间至您的帐号(我们建议这种方式),并选择升级版的硬件以提升处理速度。 如需更多帮助或报告其他bug,请联系ydd409@163.com。
2
 
3
+ TransDis, a semantic distance scoring system based on transformer-based language models, can be a useful tool to automatically assess originality and flexibility for AUT in Chinese or other languages (see the paper at https://link.springer.com/article/10.3758/s13428-023-02313-z). Enter your participant ID + prompt + response data, one per line, with a COMMA between each variable. You can either input data directly into the text box or upload a comma-separated CSV file or an XLSX file as input. Please note that if both methods are used, the CSV input will take precedence over the text box input. You can choose the model to use for scoring. Please note that sentence-transformers_paraphrase-multilingual-mpnet-base-v2, sentence-transformers_paraphrase-multilingual-MiniLM-L12-v2, and Qwen/Qwen3-Embedding-0.6B are applicable to multiple languages; cyclone_simcse-chinese-roberta-wwm-ext is only applicable to Chinese; sentence-transformers/all-mpnet-base-v2 and sentence-transformers/all-MiniLM-L12-v2 are only applicable to English. The pooling selector applies to legacy models; when Qwen/Qwen3-Embedding-0.6B is selected, the UI switches and locks the selector to last-token, and the backend uses its recommended pooling and vector normalization. If an error occurs, try simplifying your data - does it work with fewer rows? If not, the input format may be incorrect. If the process is sluggish, you have the option to duplicate this space to your account (we recommend this approach) and choose an enhanced hardware configuration for improved processing speed. For more assistance or to report potential issues with our system, please contact ydd409@163.com.
4
 
5
+ Reference: Yang, T., Zhang, Q., Sun, Z., & Hou, Y. (2023). Automatic Assessment of Divergent Thinking in Chinese Language with TransDis: A Transformer-Based Language Model Approach. Behavior Research Methods. https://doi.org/10.3758/s13428-023-02313-z
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  numpy<3
2
  pandas>=2.2,<4
3
 
4
- transformers>=4.45,<6
5
  sentence-transformers>=3,<6
6
 
7
  torch>=2.4,<3
 
1
  numpy<3
2
  pandas>=2.2,<4
3
 
4
+ transformers>=4.51,<6
5
  sentence-transformers>=3,<6
6
 
7
  torch>=2.4,<3
tests/test_app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ import app
4
+ from utils.models import QWEN3_EMBEDDING_MODEL
5
+
6
+
7
+ class PoolingUiTest(unittest.TestCase):
8
+ def test_pooling_dropdown_defaults_to_legacy_choices(self):
9
+ self.assertEqual(app.pooling_dropdown.value, 'mean')
10
+ self.assertEqual(app.pooling_dropdown.choices, [('mean', 'mean'), ('cls', 'cls')])
11
+
12
+ def test_qwen3_pooling_update_forces_last_token(self):
13
+ update = app.update_pooling_for_model(QWEN3_EMBEDDING_MODEL)
14
+
15
+ self.assertEqual(update['choices'], ['last-token'])
16
+ self.assertEqual(update['value'], 'last-token')
17
+ self.assertFalse(update['interactive'])
18
+
19
+ def test_legacy_pooling_update_restores_mean_cls(self):
20
+ update = app.update_pooling_for_model('bert-base-chinese')
21
+
22
+ self.assertEqual(update['choices'], ['mean', 'cls'])
23
+ self.assertEqual(update['value'], 'mean')
24
+ self.assertTrue(update['interactive'])
25
+
26
+ def test_qwen3_effective_pooling_ignores_api_pooling_value(self):
27
+ self.assertEqual(app.effective_pooling(QWEN3_EMBEDDING_MODEL, 'mean'), 'last-token')
28
+
29
+
30
+ if __name__ == '__main__':
31
+ unittest.main()
tests/test_models.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+ import unittest
3
+ from unittest.mock import patch
4
+
5
+ import torch
6
+
7
+ from utils import models
8
+
9
+
10
+ class FakeTokenizer:
11
+ def __init__(self):
12
+ self.calls = []
13
+
14
+ def __call__(
15
+ self,
16
+ text,
17
+ padding,
18
+ truncation,
19
+ max_length,
20
+ return_tensors,
21
+ ):
22
+ self.calls.append(
23
+ {
24
+ 'text': text,
25
+ 'padding': padding,
26
+ 'truncation': truncation,
27
+ 'max_length': max_length,
28
+ 'return_tensors': return_tensors,
29
+ }
30
+ )
31
+ return {
32
+ 'input_ids': torch.tensor([[101, 102, 0]]),
33
+ 'attention_mask': torch.tensor([[1, 1, 0]]),
34
+ }
35
+
36
+
37
+ class FakeModel:
38
+ def __init__(self):
39
+ self.device = torch.device('cpu')
40
+ self.eval_called = False
41
+
42
+ def to(self, device):
43
+ self.device = torch.device(device)
44
+ return self
45
+
46
+ def eval(self):
47
+ self.eval_called = True
48
+ return self
49
+
50
+ def __call__(self, **inputs):
51
+ hidden_states = torch.tensor(
52
+ [[[3.0, 0.0], [0.0, 4.0], [5.0, 12.0]]],
53
+ device=self.device,
54
+ )
55
+ return SimpleNamespace(last_hidden_state=hidden_states)
56
+
57
+
58
+ class Qwen3EmbeddingTest(unittest.TestCase):
59
+ def tearDown(self):
60
+ models.get_embedding_model.cache_clear()
61
+
62
+ def test_qwen3_model_is_available_without_changing_default(self):
63
+ self.assertEqual(
64
+ models.list_models[0],
65
+ 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
66
+ )
67
+ self.assertIn(models.QWEN3_EMBEDDING_MODEL, models.list_models)
68
+
69
+ def test_qwen3_uses_official_pooling_shape_and_normalization(self):
70
+ tokenizer = FakeTokenizer()
71
+ model = FakeModel()
72
+ models.get_embedding_model.cache_clear()
73
+
74
+ with (
75
+ patch.object(models.AutoTokenizer, 'from_pretrained', return_value=tokenizer) as load_tokenizer,
76
+ patch.object(models.AutoModel, 'from_pretrained', return_value=model) as load_model,
77
+ ):
78
+ embedding_model = models.get_embedding_model(models.QWEN3_EMBEDDING_MODEL)
79
+ embedding = embedding_model('hello', pooling='cls')
80
+
81
+ load_tokenizer.assert_called_once_with(
82
+ models.QWEN3_EMBEDDING_MODEL,
83
+ padding_side='left',
84
+ )
85
+ load_model.assert_called_once_with(models.QWEN3_EMBEDDING_MODEL)
86
+ self.assertTrue(model.eval_called)
87
+ self.assertEqual(tokenizer.calls[0]['max_length'], 8192)
88
+ self.assertEqual(tuple(embedding.shape), (2,))
89
+ torch.testing.assert_close(embedding.cpu(), torch.tensor([0.0, 1.0]))
90
+ self.assertAlmostEqual(torch.linalg.vector_norm(embedding).item(), 1.0)
91
+
92
+
93
+ if __name__ == '__main__':
94
+ unittest.main()
tests/test_pipeline.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch
3
+
4
+ import pandas as pd
5
+ import torch
6
+
7
+ from utils import pipeline
8
+
9
+
10
+ class FakeEmbeddingModel:
11
+ def __init__(self, vectors):
12
+ self.vectors = vectors
13
+ self.calls = []
14
+
15
+ def __call__(self, text, pooling='mean'):
16
+ self.calls.append((text, pooling))
17
+ return torch.tensor(self.vectors[text])
18
+
19
+
20
+ class PipelineFactoryTest(unittest.TestCase):
21
+ def test_originality_uses_embedding_factory(self):
22
+ model = FakeEmbeddingModel(
23
+ {
24
+ 'prompt': [1.0, 0.0],
25
+ 'response': [0.0, 1.0],
26
+ }
27
+ )
28
+ df = pd.DataFrame({'prompt': ['prompt'], 'response': ['response']})
29
+
30
+ with patch.object(pipeline, 'get_embedding_model', return_value=model) as factory:
31
+ result = pipeline.p0_originality(df, 'fake-model', 'mean')
32
+
33
+ factory.assert_called_once_with('fake-model')
34
+ self.assertAlmostEqual(result.loc[0, 'originality'], 1.0)
35
+ self.assertEqual(model.calls, [('prompt', 'mean'), ('response', 'mean')])
36
+
37
+ def test_flexibility_uses_embedding_factory(self):
38
+ model = FakeEmbeddingModel(
39
+ {
40
+ 'p': [1.0, 0.0],
41
+ 'a': [1.0, 0.0],
42
+ 'b': [0.0, 1.0],
43
+ }
44
+ )
45
+ df = pd.DataFrame(
46
+ {
47
+ 'id': [1, 1, 1],
48
+ 'prompt': ['p', 'p', 'p'],
49
+ 'response': ['a', 'b', 'a'],
50
+ }
51
+ )
52
+
53
+ with patch.object(pipeline, 'get_embedding_model', return_value=model) as factory:
54
+ result = pipeline.p1_flexibility(df, 'fake-model', 'cls')
55
+
56
+ factory.assert_called_once_with('fake-model')
57
+ self.assertEqual(len(result), 1)
58
+ self.assertAlmostEqual(result.loc[0, 'flexibility'], 2.0)
59
+ self.assertEqual(model.calls, [('a', 'cls'), ('b', 'cls'), ('a', 'cls')])
60
+
61
+
62
+ if __name__ == '__main__':
63
+ unittest.main()
utils/__pycache__/models.cpython-311.pyc DELETED
Binary file (5.89 kB)
 
utils/__pycache__/pipeline.cpython-311.pyc DELETED
Binary file (3.87 kB)
 
utils/models.py CHANGED
@@ -1,13 +1,15 @@
1
  from functools import lru_cache
2
 
3
  import torch
 
4
  from loguru import logger
5
  from sentence_transformers import SentenceTransformer
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
9
 
10
- list_models = [
11
  'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
12
  'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
13
  'sentence-transformers/all-mpnet-base-v2',
@@ -17,6 +19,8 @@ list_models = [
17
  'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
18
  ]
19
 
 
 
20
 
21
  class SBert:
22
  def __init__(self, path):
@@ -68,6 +72,52 @@ class ModelWithPooling:
68
  return o
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def test_sbert():
72
  m = SBert('bert-base-chinese')
73
  o = m('hello')
 
1
  from functools import lru_cache
2
 
3
  import torch
4
+ import torch.nn.functional as F
5
  from loguru import logger
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoTokenizer, AutoModel
8
 
9
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ QWEN3_EMBEDDING_MODEL = 'Qwen/Qwen3-Embedding-0.6B'
11
 
12
+ LEGACY_MODELS = [
13
  'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
14
  'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
15
  'sentence-transformers/all-mpnet-base-v2',
 
19
  'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
20
  ]
21
 
22
+ list_models = [*LEGACY_MODELS, QWEN3_EMBEDDING_MODEL]
23
+
24
 
25
  class SBert:
26
  def __init__(self, path):
 
72
  return o
73
 
74
 
75
+ class Qwen3Embedding:
76
+ def __init__(self, path):
77
+ logger.info(f'Start loading {self.__class__} from {path} ...')
78
+ self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
79
+ self.model = AutoModel.from_pretrained(path)
80
+ self.model.to(DEVICE)
81
+ self.model.eval()
82
+ logger.info(f'Load {self.__class__} from {path} ...')
83
+
84
+ @staticmethod
85
+ def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
86
+ left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
87
+ if left_padding:
88
+ return last_hidden_states[:, -1]
89
+
90
+ sequence_lengths = attention_mask.sum(dim=1) - 1
91
+ batch_size = last_hidden_states.shape[0]
92
+ return last_hidden_states[
93
+ torch.arange(batch_size, device=last_hidden_states.device),
94
+ sequence_lengths,
95
+ ]
96
+
97
+ @lru_cache(maxsize=100)
98
+ @torch.no_grad()
99
+ def __call__(self, text: str, pooling='mean'):
100
+ inputs = self.tokenizer(
101
+ text,
102
+ padding=True,
103
+ truncation=True,
104
+ max_length=8192,
105
+ return_tensors='pt',
106
+ )
107
+ inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
108
+ outputs = self.model(**inputs)
109
+ embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
110
+ embeddings = F.normalize(embeddings, p=2, dim=1)
111
+ return embeddings.squeeze(0)
112
+
113
+
114
+ @lru_cache(maxsize=8)
115
+ def get_embedding_model(model_name: str):
116
+ if model_name == QWEN3_EMBEDDING_MODEL:
117
+ return Qwen3Embedding(model_name)
118
+ return ModelWithPooling(model_name)
119
+
120
+
121
  def test_sbert():
122
  m = SBert('bert-base-chinese')
123
  o = m('hello')
utils/pipeline.py CHANGED
@@ -3,7 +3,7 @@ from typing import List
3
  import pandas as pd
4
  from sentence_transformers.util import cos_sim
5
 
6
- from utils.models import ModelWithPooling
7
 
8
 
9
  def p0_originality(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
@@ -15,7 +15,7 @@ def p0_originality(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFr
15
  """
16
  assert 'prompt' in df.columns
17
  assert 'response' in df.columns
18
- model = ModelWithPooling(model_name)
19
 
20
  def get_cos_sim(prompt: str, response: str) -> float:
21
  prompt_vec = model(text=prompt, pooling=pooling)
@@ -37,7 +37,7 @@ def p1_flexibility(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFr
37
  assert 'prompt' in df.columns
38
  assert 'response' in df.columns
39
  assert 'id' in df.columns
40
- model = ModelWithPooling(model_name)
41
 
42
  def get_flexibility(responses: List[str]) -> float:
43
  responses_vec = [model(text=_, pooling=pooling) for _ in responses]
 
3
  import pandas as pd
4
  from sentence_transformers.util import cos_sim
5
 
6
+ from utils.models import get_embedding_model
7
 
8
 
9
  def p0_originality(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
 
15
  """
16
  assert 'prompt' in df.columns
17
  assert 'response' in df.columns
18
+ model = get_embedding_model(model_name)
19
 
20
  def get_cos_sim(prompt: str, response: str) -> float:
21
  prompt_vec = model(text=prompt, pooling=pooling)
 
37
  assert 'prompt' in df.columns
38
  assert 'response' in df.columns
39
  assert 'id' in df.columns
40
+ model = get_embedding_model(model_name)
41
 
42
  def get_flexibility(responses: List[str]) -> float:
43
  responses_vec = [model(text=_, pooling=pooling) for _ in responses]