JiachenFu committed on
Commit
6600152
·
0 Parent(s):

update: app

Browse files
Files changed (9) hide show
  1. .gitignore +165 -0
  2. .gradio/certificate.pem +31 -0
  3. .gradio/flagged/dataset1.csv +64 -0
  4. DetectAnyLLM +1 -0
  5. LICENSE +35 -0
  6. README.md +14 -0
  7. app.py +539 -0
  8. core/model.py +255 -0
  9. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ ckpt/*/
12
+ logs/*/
13
+ models/*/
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113
+ .pdm.toml
114
+ .pdm-python
115
+ .pdm-build/
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
.gradio/flagged/dataset1.csv ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_text,output,timestamp
2
+ "'def greet(input_text):
3
+ sub_texts = text_splitter.split_text(input_text) # 修改为split_text获取文本列表
4
+ html_output = []
5
+ for sub_text in sub_texts:
6
+ tokenized = scoring_tokenizer(sub_text, truncation=True, return_tensors=""pt"", padding=True, return_token_type_ids=False).to(device)
7
+ labels = tokenized.input_ids[:, 1:]
8
+ with torch.no_grad():
9
+ logits_score = scoring_model(**tokenized).logits[:, :-1]
10
+ logits_ref = logits_score
11
+ crit, _ = criterion_fn(logits_ref, logits_score, labels)
12
+
13
+ crit = crit.cpu().numpy().item()
14
+ prob = prob_estimator.crit_to_prob(crit)
15
+
16
+ # 根据概率值设置颜色
17
+ if prob >= 0.7:
18
+ color = ""red""
19
+ elif prob >= 0.3:
20
+ color = ""orange""
21
+ else:
22
+ color = ""white""
23
+
24
+ # 创建带样式的HTML内容
25
+ html_output.append(f'<span style=""color: {color};"">{sub_text} (Probability: {prob:.2f})</span>')
26
+
27
+ # 用换行连接所有结果
28
+ return ""<br>"".join(html_output)
29
+
30
+ demo = gr.Interface(
31
+ fn=greet,
32
+ inputs=[""text""],
33
+ outputs=gr.HTML() # 修改为HTML输出组件
34
+ )","'<span style=""color: white;"">def greet(input_text):
35
+ sub_texts = text_splitter.split_text(input_text) # 修改为split_text获取文本列表
36
+ html_output = []
37
+ for (Probability: 0.09)</span><br><span style=""color: white;"">sub_text in sub_texts:
38
+ tokenized = scoring_tokenizer(sub_text, truncation=True, return_tensors=""pt"", padding=True, retur (Probability: 0.03)</span><br><span style=""color: white;"">n_token_type_ids=False).to(device)
39
+ labels = tokenized.input_ids[:, 1:]
40
+ with torch.no_grad():
41
+ logits_ (Probability: 0.05)</span><br><span style=""color: white;"">score = scoring_model(**tokenized).logits[:, :-1]
42
+ logits_ref = logits_score
43
+ crit, _ = criterion_fn(logit (Probability: 0.00)</span><br><span style=""color: white;"">s_ref, logits_score, labels)
44
+
45
+ crit = crit.cpu().numpy().item()
46
+ prob = prob_estimator.crit_to_prob(crit) (Probability: 0.02)</span><br><span style=""color: white;""># 根据概率值设置颜色
47
+ if prob >= 0.7:
48
+ color = ""red""
49
+ elif prob >= 0.3:
50
+ color = ""or (Probability: 0.09)</span><br><span style=""color: white;"">ange""
51
+ else:
52
+ color = ""white""
53
+
54
+ # 创建带样式的HTML内容
55
+ html_output.append(f'<span style=""color: (Probability: 0.19)</span><br><span style=""color: white;"">{color};"">{sub_text} (Probability: {prob:.2f})</span>')
56
+
57
+ # 用换行连接所有结果
58
+ return ""<br>"".join(html_output)
59
+
60
+ demo = gr.Int (Probability: 0.01)</span><br><span style=""color: white;"">erface(
61
+ fn=greet,
62
+ inputs=[""text""],
63
+ outputs=gr.HTML() # 修改为HTML输出组件
64
+ ) (Probability: 0.06)</span>",2025-01-30 11:44:36.020197
DetectAnyLLM ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 2d182abad5143fc1183cfedca2a30f58c3d44e7e
LICENSE ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Pi-Lab License 1.0
2
+
3
+ Copyright 2025 Pi-Lab
4
+
5
+ Redistribution and use for non-commercial purpose in source and
6
+ binary forms, with or without modification, are permitted provided
7
+ that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright
10
+ notice, this list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright
13
+ notice, this list of conditions and the following disclaimer in
14
+ the documentation and/or other materials provided with the
15
+ distribution.
16
+
17
+ 3. Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived
19
+ from this software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+
33
+ In the event that redistribution and/or use for commercial purpose in
34
+ source or binary forms, with or without modification is required,
35
+ please contact the contributor(s) of the work.
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: DetectAnyLLM
3
+ emoji: 🔥
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.46.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ short_description: '[ACMMM 2025] State-Of-The-Art AI-Text Detector'
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ import json
6
+ from core.model import DiscrepancyEstimator
7
+ import re
8
+ import docx
9
+ import spaces
10
+ from datasets import load_dataset
11
+
12
+
13
def read_file_content(file):
    """Return the text content of an uploaded file.

    Supports plain-text (.txt) and Word (.docx) uploads; any other
    input — including no file at all — yields an empty string.
    """
    if file is None:
        return ""
    name = file.name
    if name.endswith('.txt'):
        with open(name, 'r', encoding='utf-8') as handle:
            return handle.read()
    if name.endswith('.docx'):
        # python-docx exposes paragraphs; join them with newlines.
        document = docx.Document(name)
        return '\n'.join(paragraph.text for paragraph in document.paragraphs)
    # Unsupported extension: silently ignore, matching the UI's contract.
    return ""
26
+
27
def split_sentences(text):
    """Split *text* into sentences on Chinese (。) or Western (.) full
    stops, keeping each terminator attached to its sentence.

    Empty/whitespace-only fragments are dropped.
    """
    parts = re.split(r'([。.])', text)
    merged = []
    # re.split with a capturing group alternates text / separator, so
    # stitch them back together pairwise.
    for idx in range(0, len(parts) - 1, 2):
        merged.append(parts[idx] + parts[idx + 1])
    if len(parts) % 2 == 1:
        # Trailing fragment with no terminator.
        merged.append(parts[-1])
    return [piece.strip() for piece in merged if piece.strip()]
34
+
35
def count_words(sentence, language='Chinese'):
    """Measure a sentence's length: characters for Chinese, otherwise
    whitespace-delimited words. Newlines are stripped first.
    """
    cleaned = sentence.replace('\n', '').replace('\r', '')
    if language == 'Chinese':
        return len(cleaned)
    return len(cleaned.split())
38
+
39
def segment_text(sentences, language='Chinese'):
    """Regroup *sentences* into segments of roughly 100 words/characters.

    Rules: sentences accumulate until a segment would exceed 100 units;
    a single sentence longer than 100 units is merged with its successor
    when the pair stays within 200 units, otherwise stands alone; a final
    segment shorter than 100 units is folded back into the previous one
    when the merge stays within 200 units.

    Bug fix: the original used ``i += 1`` inside ``for i, ... in
    enumerate(...)`` to skip the sentence just merged — in Python that
    does nothing, so the next sentence was emitted twice (once in the
    merged pair, once on its own). An explicit while-loop index makes
    the skip real.

    :param sentences: list of sentence strings (see split_sentences).
    :param language: 'Chinese' counts characters and joins with '',
        anything else counts words and joins with ' '.
    :return: list of segment strings.
    """
    joiner = '' if language == 'Chinese' else ' '
    result = []
    current_segment = []
    current_length = 0

    i = 0
    total = len(sentences)
    while i < total:
        sentence = sentences[i]
        word_count = count_words(sentence, language)

        if word_count > 100:
            # Oversized sentence: try to pair it with the next one.
            if i + 1 < total and word_count + count_words(sentences[i + 1], language) <= 200:
                if current_segment:  # flush the accumulated segment first
                    result.append(joiner.join(current_segment))
                result.append(sentence + joiner + sentences[i + 1])
                current_segment = []
                current_length = 0
                i += 2  # consume BOTH sentences (this skip now works)
                continue
            # Too big to pair: emit it on its own.
            if current_segment:
                result.append(joiner.join(current_segment))
            result.append(sentence)
            current_segment = []
            current_length = 0
        elif current_length + word_count > 100:
            # Segment full: flush it and start a new one.
            if current_segment:
                result.append(joiner.join(current_segment))
            current_segment = [sentence]
            current_length = word_count
        else:
            # Keep accumulating.
            current_segment.append(sentence)
            current_length += word_count
        i += 1

    # Handle the trailing segment.
    if current_segment:
        if current_length < 100 and result and current_length + count_words(result[-1], language) <= 200:
            # Short tail: fold it back into the previous segment.
            last_segment = result.pop()
            pieces = last_segment.split() if language != 'Chinese' else list(last_segment)
            result.append(joiner.join(pieces + current_segment))
        else:
            result.append(joiner.join(current_segment))

    return result
90
+
91
def extract_latex_text(latex_source):
    """Strip LaTeX markup from *latex_source*, returning best-effort
    plain text for detection.
    """
    # Prefer the body of the document environment when one exists.
    body_re = re.compile(r'\\begin{document}(.*?)\\end{document}', re.DOTALL)
    found = body_re.search(latex_source)
    text = found.group(1) if found else latex_source

    # Drop comments: an unescaped % up to end of line.
    text = re.sub(r'(?<!\\)%.*', '', text, flags=re.MULTILINE)

    # Remove common non-text environments wholesale.
    skip_envs = ['figure', 'table', 'equation', 'align\*?', 'verbatim', 'lstlisting']
    env_re = re.compile(
        r'\\begin{(' + '|'.join(skip_envs) + r')}.*?\\end{\1}',
        re.DOTALL
    )
    text = env_re.sub('', text)

    # Remove \cite commands together with their arguments.
    text = re.sub(r'\\cite(\[[^\]]*\])?\{[^}]*\}', '', text)

    # Remove inline \table / \figure commands and their arguments.
    text = re.sub(r'\\(table|figure)\*?(\[[^\]]*\])?\{[^}]*\}', '', text)

    # Strip simple (argument-less) command names.
    text = re.sub(r'\\([a-zA-Z]+)\*?\b', '', text)

    # Unwrap argument-taking commands, iterating to peel nesting
    # (capped at 10 passes so malformed input cannot loop forever).
    for _ in range(10):
        unwrapped = re.sub(
            r'\\([a-zA-Z]+)\*?(?:\[.*?\])*{((?:[^{}]*|{[^{}]*})*)}',
            lambda m: m.group(2),
            text,
            flags=re.DOTALL
        )
        if unwrapped == text:
            break
        text = unwrapped

    # Normalise escapes, line breaks and typographic quotes.
    replacements = {
        '~': ' ', '\\&': '&', '\\$': '$', '\\%': '%',
        '\\_': '_', '\\#': '#', '\\\\': '\n', '\n': ' ',
        '“': '"', '”': '"', '‘': "'", '’': "'"
    }
    for old, new in replacements.items():
        text = text.replace(old, new)

    # Collapse runs of whitespace.
    text = re.sub(r'[ \t]+', ' ', text)
    text = re.sub(r'\n{2,}', '\n\n', text)
    return text.strip()
142
+
143
class ProbEstimator:
    """Maps a raw discrepancy score onto per-task probabilities of being
    machine-written, using reference score distributions loaded from a
    Hugging Face dataset repository.
    """

    def __init__(self, ref_file_dir):
        """Load reference discrepancy scores for each detection task
        from ``<ref_file_dir>/<task>.json``.
        """
        self.tasks = ["polish", "generate", "rewrite"]
        self.real_crits = {"polish": [], "generate": [], "rewrite": []}
        self.fake_crits = {"polish": [], "generate": [], "rewrite": []}
        for task in self.tasks:
            reference = load_dataset(ref_file_dir, data_files=f'{task}.json')['train']
            self.real_crits[task].extend(reference['original_discrepancy'])
            self.fake_crits[task].extend(reference['rewritten_discrepancy'])
        sample_total = sum(len(self.real_crits[task]) for task in self.tasks) * 2
        print(f'ProbEstimator: total {sample_total} samples.')

    def crit_to_prob(self, crit):
        """Estimate, per task, P(machine-written | score == *crit*).

        Uses a windowed nearest-neighbour count: the window radius is
        the 10th-percentile distance from *crit* to all reference
        scores; the probability is the fraction of fake scores among
        reference scores inside the window (0.5 when none fall inside).
        """
        probs = {}
        for task in self.tasks:
            real = np.array(self.real_crits[task])
            fake = np.array(self.fake_crits[task])
            total_len = real.size + fake.size
            distances = np.abs(np.concatenate((real, fake)) - crit)
            offset = np.sort(distances)[int(0.1 * total_len)]
            low, high = crit - offset, crit + offset
            cnt_real = np.sum((real > low) & (real < high))
            cnt_fake = np.sum((fake > low) & (fake < high))
            in_window = cnt_real + cnt_fake
            probs[task] = (cnt_fake / in_window) if in_window > 0 else 0.5
        return probs
165
+
166
# Inference device. Assumes a CUDA GPU is present (the app is decorated
# with @spaces.GPU below) — NOTE(review): no CPU fallback; confirm the
# deployment always provides a GPU.
device = 'cuda'
# Reference discrepancy-score distributions, loaded once at startup from
# the Hugging Face Hub — one estimator per supported language.
zh_prob_estimator = ProbEstimator(ref_file_dir="JiachenFu/Qwen2-0.5B-detectanyllm-detector-ref-zh")
en_prob_estimator = ProbEstimator(ref_file_dir="JiachenFu/Qwen2-0.5B-detectanyllm-detector-ref-en")
169
+
170
@spaces.GPU
def greet(mode, language, input_text):
    """Run AI-text detection over *input_text* and return an HTML report.

    Args:
        mode: "LaTex" strips LaTeX markup before detection; any other
            value treats the input as plain text.
        language: "Chinese" or anything else (English); selects the
            detector checkpoint and the reference-score estimator.
        input_text: raw text (or LaTeX source) to analyse.

    Returns:
        An HTML string with each ~100-word segment coloured by verdict
        (red = likely AI-generated, orange = likely AI-revised,
        black = likely human) plus an overall-rate summary box.
    """
    if mode == "LaTex":
        input_text = extract_latex_text(input_text)
    # Split into sentences, then regroup into ~100-unit scoring segments.
    split_texts = split_sentences(input_text)
    sub_texts = segment_text(split_texts, language=language)
    detected = []
    # NOTE(review): the detector is re-instantiated from the Hub on every
    # request; caching the model would avoid repeated loads — confirm this
    # is intentional (e.g. for @spaces.GPU memory management).
    if language == "Chinese":
        model = DiscrepancyEstimator(load_directory="JiachenFu/Qwen2-0.5B-detectanyllm-detector-zh").to(device)
        prob_estimator = zh_prob_estimator
    else:
        model = DiscrepancyEstimator(load_directory="JiachenFu/Qwen2-0.5B-detectanyllm-detector-en").to(device)
        prob_estimator = en_prob_estimator
    model.eval()
    # Score each segment independently.
    for i, sub_text in enumerate(sub_texts):
        text_content = sub_text
        print(f'processing {sub_text}')
        tokens = model.scoring_tokenizer(
            text_content, return_tensors='pt', padding=True, truncation=True, return_token_type_ids=False
        )
        print(f'tokenized')
        input_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)
        with torch.no_grad():
            # Reference inputs are None: the model scores against itself.
            output = model.get_discrepancy_of_scoring_and_reference_models(
                input_ids_for_scoring_model=input_ids,
                attention_mask_for_scoring_model=attention_mask,
                input_ids_for_reference_model=None,
                attention_mask_for_reference_model=None,
            )
        discrepancy = output['scoring_discrepancy']
        discrepancy = discrepancy.cpu().numpy().item()
        print(f'discrepancy: {discrepancy}')
        probs = prob_estimator.crit_to_prob(discrepancy)
        # Below this raw-score threshold the segment is forced to "human"
        # regardless of the reference distributions.
        # NOTE(review): 15 appears to be an empirical cutoff — confirm.
        if discrepancy < 15:
            for task in probs.keys():
                probs[task] = 0.0
        detected.append({
            'order': i,
            'text': text_content,
            'words_count': len(text_content) if language == "Chinese" else len(text_content.split()),
            'probs': probs
        })

    # Absolutely-positioned summary box is appended at the end;
    # first build the per-character reveal animation markup.
    html_output = '''
    <style>
    @keyframes reveal {
        from { opacity: 0; }
        to { opacity: 1; }
    }
    .reveal-char {
        opacity: 0;
        animation: reveal 0.2s forwards;
        white-space: pre-wrap;
    }
    </style>
    <div style="position: relative; padding-bottom: 60px; min-height: 120px;">
    '''

    current_delay = 0.0  # current animation start delay (seconds)
    char_duration = 0.001  # delay increment per character (seconds)

    # Colour each segment by its dominant verdict, one <span> per character.
    for item in detected:
        ai_generate_prob = item['probs']['generate']
        ai_revise_prob = max(item['probs']['polish'], item['probs']['rewrite'])
        prob = max(ai_generate_prob, ai_revise_prob)
        if prob >= 0.75:
            if ai_generate_prob >= ai_revise_prob:
                color = "red"
                item["generate"] = 1
                item["revise"] = 0
            else:
                color = "orange"
                item["generate"] = 0
                item["revise"] = 1
        else:
            color = "black"
            item["generate"] = 0
            item["revise"] = 0

        for char in item['text']:
            html_output += f'<span class="reveal-char" style="color: {color}; animation-delay: {current_delay:.2f}s;">{char}</span>'
            current_delay += char_duration

    # Length-weighted overall rates over all segments.
    total_length = sum(item['words_count'] for item in detected)
    # total_prob = sum(item['prob'] * item['words_count'] for item in detected) / total_length if total_length > 0 else 0
    generate_prob = sum(item["generate"] * item["words_count"] for item in detected) / total_length if total_length > 0 else 0
    revise_prob = sum(item["revise"] * item["words_count"] for item in detected) / total_length if total_length > 0 else 0
    html_output += f'''
    <div style="
        position: absolute;
        bottom: 0;
        right: 0;
        background-color: rgba(255, 255, 255, 0.9);
        padding: 8px 12px;
        border-radius: 4px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        border: 1px solid #e0e0e0;
        font-size: 14px;
    ">
        🤖 AI Generated Rate: <strong>{generate_prob:.2%}</strong><br>
        ✍️ AI Revised Rate: <strong>{revise_prob:.2%}</strong>
    </div>
    '''

    html_output += '</div>'
    return html_output
280
+
281
# Use gr.Blocks instead of gr.Interface for full control over layout.
# All visual styling is carried by the css string below.
with gr.Blocks(css="""
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');

:root {
    --accent-color: #6366f1;
    --text-color: #374151;
    --border-color: #e5e7eb;
    --background-light: #f9fafb;
    --background-card: #ffffff;
}

body, .gradio-container {
    background: var(--background-light);
    font-family: 'Inter', sans-serif;
    color: var(--text-color);
}

#header {
    text-align: center;
    padding: 2rem;
    margin: 0 auto; /* Use gap for spacing, remove margin-bottom */
    background-color: var(--background-card);
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40' viewBox='0 0 40 40'%3E%3Cg fill-rule='evenodd'%3E%3Cg fill='%23e5e7eb' fill-opacity='0.3'%3E%3Cpath d='M0 38.59l2.83-2.83 1.41 1.41L1.41 40H0v-1.41zM0 1.4l2.83 2.83 1.41-1.41L1.41 0H0v1.41zM38.59 40l-2.83-2.83 1.41-1.41L40 38.59V40h-1.41zM40 1.41l-2.83 2.83-1.41-1.41L38.59 0H40v1.41zM20 18.6l2.83-2.83 1.41 1.41L21.41 20l2.83 2.83-1.41 1.41L20 21.41l-2.83 2.83-1.41-1.41L18.59 20l-2.83-2.83 1.41-1.41L20 18.59z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E");
    border: 1px solid var(--border-color);
    border-radius: 16px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.05);
}
#title {
    font-weight: 800;
    font-size: 2.5em;
    letter-spacing: -0.02em;
    color: var(--text-color);
    margin-bottom: 0.25em;
}
.detect-grad {
    background: -webkit-linear-gradient(left, #ff8c8c, #ffc89e);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-weight: 800;
}
.anyllm-grad {
    background: -webkit-linear-gradient(left, #a0e6ff, #aaffd4);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-weight: 800;
}
#authors {
    font-size: 1.1em;
    color: #6b7280;
    margin: 0;
}

#main-container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 0 1rem;
    gap: 2rem; /* Add gap for consistent spacing */
}

#controls-row {
    justify-content: center;
    gap: 2rem;
}

/* Custom styles for Radio Button Groups */
#controls-row > div {
    background-color: var(--background-card);
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40' viewBox='0 0 40 40'%3E%3Cg fill-rule='evenodd'%3E%3Cg fill='%23e5e7eb' fill-opacity='0.3'%3E%3Cpath d='M0 38.59l2.83-2.83 1.41 1.41L1.41 40H0v-1.41zM0 1.4l2.83 2.83 1.41-1.41L1.41 0H0v1.41zM38.59 40l-2.83-2.83 1.41-1.41L40 38.59V40h-1.41zM40 1.41l-2.83 2.83-1.41-1.41L38.59 0H40v1.41zM20 18.6l2.83-2.83 1.41 1.41L21.41 20l2.83 2.83-1.41 1.41L20 21.41l-2.83 2.83-1.41-1.41L18.59 20l-2.83-2.83 1.41-1.41L20 18.59z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E");
    border: 1px solid var(--border-color);
    border-radius: 16px;
    padding: 1rem;
    box-shadow: 0 4px 12px rgba(0,0,0,0.05);
}

#controls-row .gradio-button {
    border-radius: 10px !important;
    transition: background-color 0.2s ease, color 0.2s ease;
}

#controls-row .gradio-button.selected {
    background: var(--accent-color) !important;
    color: white !important;
    border-color: var(--accent-color) !important;
}

#content-row {
    gap: 1.5rem;
}

.card {
    background-color: var(--background-card);
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40' viewBox='0 0 40 40'%3E%3Cg fill-rule='evenodd'%3E%3Cg fill='%23e5e7eb' fill-opacity='0.3'%3E%3Cpath d='M0 38.59l2.83-2.83 1.41 1.41L1.41 40H0v-1.41zM0 1.4l2.83 2.83 1.41-1.41L1.41 0H0v1.41zM38.59 40l-2.83-2.83 1.41-1.41L40 38.59V40h-1.41zM40 1.41l-2.83 2.83-1.41-1.41L38.59 0H40v1.41zM20 18.6l2.83-2.83 1.41 1.41L21.41 20l2.83 2.83-1.41 1.41L20 21.41l-2.83 2.83-1.41-1.41L18.59 20l-2.83-2.83 1.41-1.41L20 18.59z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E");
    border: 1px solid var(--border-color);
    border-radius: 16px;
    padding: 1.5rem;
    box-shadow: 0 4px 12px rgba(0,0,0,0.05);
    height: 100%;
    display: flex;
    flex-direction: column;
    gap: 1rem;
}

.card-title {
    font-weight: 600;
    font-size: 1.2rem;
    color: var(--text-color);
    padding-bottom: 0.75rem;
    border-bottom: 1px solid var(--border-color);
}

#input-text textarea {
    flex-grow: 1;
    border: none !important;
    box-shadow: none !important;
    padding: 0 !important;
    font-size: 1.1em;
    line-height: 1.7;
}

#result-html {
    flex-grow: 1;
    font-size: 1.1em;
    line-height: 1.7;
    overflow-y: auto;
    height: 520px;
}

#input-footer {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-top: auto; /* Push to bottom */
}

#char-counter {
    font-size: 0.9em;
    color: #9ca3af;
}
#char-counter.error {
    color: #ef4444;
}

#submit-btn {
    flex-grow: 1;
    max-width: 200px;
    font-size: 1.05em;
    font-weight: 600;
    background: var(--accent-color);
    color: white;
    border-radius: 10px;
}
#submit-btn:hover {
    background: #4f46e5;
}

.disclaimer {
    text-align: center;
    margin: 0 auto; /* Remove vertical margins */
    color: #64748b;
    font-size: 1.1em;
    max-width: 800px;
}
/* Reveal 动画更丝滑 */
@keyframes reveal {
    from { opacity: 0; }
    to { opacity: 1; }
}
.reveal-char {
    opacity: 0;
    animation: reveal 0.2s forwards;
    white-space: pre-wrap;
}
""") as demo:
    with gr.Column(elem_id="main-container"):
        # Page header: title + authors.
        gr.Markdown("""
<div id="header">
    <h1 id="title"><span class="detect-grad">Detect</span><span class="anyllm-grad">AnyLLM</span>: Towards Generalizable and Robust Detection of Machine-Generated Text Across Domains and Models</h1>
    <p id="authors">Jiachen Fu, Chun-Le Guo, Chongyi Li</p>
</div>
""")

        # Detection options: language and input type.
        with gr.Row(elem_id="controls-row"):
            language_radio = gr.Radio(
                choices=["English", "Chinese"],
                value="English",
                label="🌐 Language",
                interactive=True
            )
            mode_radio = gr.Radio(
                choices=["Text-Only", "LaTex"],
                value="Text-Only",
                label="✍️ Input Type",
                interactive=True
            )

        # Two-column layout: input card (left) and result card (right).
        with gr.Row(equal_height=True, elem_id="content-row"):
            with gr.Column(scale=1, min_width=500):
                with gr.Column(elem_classes="card"):
                    gr.HTML('<div class="card-title">📝 Input</div>')
                    upload_btn = gr.File(
                        label="Upload File (txt, docx)",
                        file_types=['.txt', '.docx'],
                        elem_id="upload-btn"
                    )
                    input_text = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text to detect or upload a file...",
                        lines=15,
                        elem_id="input-text",
                        max_length=100000,
                    )
                    with gr.Row(elem_id="input-footer"):
                        counter_html = gr.HTML("<div id='char-counter'>0/100000</div>")
                        submit_btn = gr.Button("✨ Detect", variant="primary", elem_id="submit-btn")

            with gr.Column(scale=1, min_width=500):
                with gr.Column(elem_classes="card"):
                    gr.HTML('<div class="card-title">🔍 Result</div>')
                    result = gr.HTML(elem_id="result-html")

        # Colour-legend disclaimer shown under the cards.
        gr.HTML("""
<div class="disclaimer">
    💡 <i><b style="color: red;">Red fonts</b> indicate a high probability of AI generation. <b style="color: orange;">Orange fonts</b> indicate a high probability of AI revision or polishing. The detection results are for reference only.</i>
</div>
""")

    # Uploading a file replaces the textbox content with the file's text.
    upload_btn.upload(
        read_file_content,
        inputs=upload_btn,
        outputs=input_text
    )

    # Client-side character counter; runs in the browser only (js=).
    input_text.input(
        None,
        [input_text],
        None,
        js="""
        (text) => {
            setTimeout(() => {
                const counter = document.getElementById("char-counter");
                if (counter) {
                    const length = text.length;
                    counter.innerHTML = `${length}/100000`;
                    counter.classList.toggle("error", length > 100000);
                }
            }, 0);
            return text;
        }
        """
    )
    # Detect button triggers the GPU-backed scoring pipeline.
    submit_btn.click(
        greet,
        inputs=[mode_radio, language_radio, input_text],
        outputs=result
    )

demo.launch(share=True)
core/model.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import sys
5
+ import os
6
+ import time
7
+ import copy
8
+ from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForCausalLM
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
def from_pretrained(cls, model_name, kwargs=None, cache_dir=None):
    """Load a Hugging Face model/tokenizer, preferring a local cached copy.

    Args:
        cls: Hugging Face loader class (e.g. ``AutoModelForCausalLM`` or
            ``AutoTokenizer``) whose ``from_pretrained`` is invoked.
        model_name: Hub repo id (possibly containing ``/``) or plain name.
        kwargs: Optional dict of extra keyword arguments forwarded to
            ``cls.from_pretrained``. Defaults to ``{}``.
        cache_dir: Local cache directory. When it contains a folder named
            after the repo's final path component, that folder is loaded
            directly without contacting the Hub.

    Returns:
        Whatever ``cls.from_pretrained`` returns (model or tokenizer).
    """
    if kwargs is None:
        kwargs = {}
    if cache_dir is not None:
        # `split("/")[-1]` yields the name unchanged when there is no "/",
        # so one expression covers both repo-id and plain-name inputs.
        local_path = os.path.join(cache_dir, model_name.split("/")[-1])
        if os.path.exists(local_path):
            # Local copy exists: load it directly, skipping the Hub.
            return cls.from_pretrained(local_path, **kwargs)
    # Fall back to downloading from the Hub into cache_dir.
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir, device_map='auto')
21
+
22
+
23
class DiscrepancyEstimator(nn.Module):
    """Estimate sampling discrepancy between a scoring LM and a reference LM.

    Holds a "scoring" causal language model (optionally LoRA-adapted via
    ``add_lora_config`` or restored from a PEFT checkpoint via
    ``load_pretrained``) and, depending on ``train_method``, an optional
    "reference" causal language model.

    NOTE(review): the precise semantics of the 'DDL' and 'SPO' training
    methods are not defined in this file. From the code below: 'DDL' runs
    without a separate reference model (the scoring logits double as the
    reference distribution), while 'SPO' requires one (a deep copy of the
    scoring model is made if none is supplied).
    """

    def __init__(self,
                 scoring_model_name: str=None,
                 reference_model_name: str=None,
                 scoring_model: AutoModelForCausalLM=None,
                 reference_model: AutoModelForCausalLM=None,
                 scoring_tokenizer: AutoTokenizer=None,
                 reference_tokenizer: AutoTokenizer=None,
                 cache_dir: str=None,
                 train_method: str='DDL',
                 pretrained_ckpt: str=None,
                 ):
        """Build the estimator, loading models by name or taking instances.

        Args:
            scoring_model_name: Hub name of the scoring model; when given,
                the model and its tokenizer are loaded from ``cache_dir``.
            reference_model_name: Hub name of the reference model.
            scoring_model / scoring_tokenizer: pre-instantiated scoring
                pair, used when ``scoring_model_name`` is None (both
                required in that case).
            reference_model / reference_tokenizer: pre-instantiated
                reference pair; must be provided together or not at all.
            cache_dir: local cache directory forwarded to the loaders.
            train_method: 'DDL' or 'SPO' (asserted below).
            pretrained_ckpt: if given, all other model arguments are
                ignored and state is restored via ``load_pretrained``.
        """
        super().__init__()
        assert train_method in ['DDL', 'SPO'], 'train_method should be DDL or SPO.'
        self.train_method = train_method
        self.cache_dir = cache_dir
        if pretrained_ckpt is not None:
            # A saved PEFT checkpoint takes priority over every other source.
            self.load_pretrained(pretrained_ckpt)
        else:
            if scoring_model_name is not None:
                # GPT-J checkpoints ship a dedicated fp16 revision; load it
                # to halve memory. Other models use default kwargs.
                if 'gpt-j' in scoring_model_name or 'GPT-J' in scoring_model_name:
                    model_kwargs = dict(
                        torch_dtype=torch.float16,
                        revision='float16'
                    )
                else:
                    model_kwargs = {}
                self.scoring_model_name = scoring_model_name
                self.scoring_model = from_pretrained(AutoModelForCausalLM,
                                                     scoring_model_name,
                                                     cache_dir=cache_dir,
                                                     kwargs=model_kwargs)
                # OPT's fast tokenizer is avoided; right padding keeps the
                # label/logit alignment used in the discrepancy code below.
                self.scoring_tokenizer = from_pretrained(AutoTokenizer,
                                                         scoring_model_name,
                                                         kwargs={'padding_side': 'right',
                                                                 'use_fast': True if 'facebook/opt-' not in scoring_model_name else False},
                                                         cache_dir=cache_dir,)
            else:
                if scoring_model is None or scoring_tokenizer is None:
                    raise ValueError('You should provide scoring_model_name or scoring_model and scoring_tokenizer.')
                self.scoring_model = scoring_model
                self.scoring_tokenizer = scoring_tokenizer
                self.scoring_model_name = scoring_model.config._name_or_path
            # Many causal LMs ship without a pad token; reuse EOS for padding.
            if self.scoring_tokenizer.pad_token is None:
                self.scoring_tokenizer.pad_token = self.scoring_tokenizer.eos_token
                self.scoring_tokenizer.pad_token_id = self.scoring_tokenizer.eos_token_id

            if reference_model_name is not None:
                if 'gpt-j' in reference_model_name or 'GPT-J' in reference_model_name:
                    model_kwargs = dict(
                        torch_dtype=torch.float16,
                        revision='float16'
                    )
                else:
                    model_kwargs = {}
                self.reference_model = from_pretrained(AutoModelForCausalLM,
                                                       reference_model_name,
                                                       cache_dir=cache_dir,
                                                       kwargs=model_kwargs)
                self.reference_tokenizer = from_pretrained(AutoTokenizer,
                                                           reference_model_name,
                                                           kwargs={'padding_side': 'right',
                                                                   'use_fast': True if 'facebook/opt-' not in reference_model_name else False},
                                                           cache_dir=cache_dir,)
                self.reference_model_name = reference_model_name
            else:
                if reference_model is None and reference_tokenizer is None:
                    if train_method == 'DDL':
                        # DDL needs no separate reference model.
                        self.reference_model = None
                        self.reference_tokenizer = None
                        self.reference_model_name = None
                    else:
                        # SPO: default the reference to a frozen-weights copy
                        # of the scoring model, sharing its tokenizer.
                        self.reference_model = copy.deepcopy(self.scoring_model)
                        self.reference_tokenizer = self.scoring_tokenizer
                        self.reference_model_name = self.reference_model.config._name_or_path
                elif reference_model is not None and reference_tokenizer is not None:
                    self.reference_model = reference_model
                    self.reference_tokenizer = reference_tokenizer
                    self.reference_model_name = reference_model.config._name_or_path
                else:
                    raise ValueError('You should provide reference_model and reference_tokenizer at the same time.')

            if self.reference_tokenizer is not None:
                if self.reference_tokenizer.pad_token is None:
                    self.reference_tokenizer.pad_token = self.reference_tokenizer.eos_token
                    self.reference_tokenizer.pad_token_id = self.reference_tokenizer.eos_token_id

    def add_lora_config(self, lora_config: LoraConfig):
        """Wrap the scoring model with a PEFT LoRA adapter in place."""
        self.lora_config = lora_config
        self.scoring_model = get_peft_model(self.scoring_model, self.lora_config)

    def load_pretrained(self, load_directory, load_directory_ref=None):
        """
        Load the model's state_dict from the specified directory.

        The scoring model is restored as a PEFT (LoRA) model from
        ``load_directory``; an optional plain reference model is restored
        from ``load_directory_ref``. Pad tokens default to EOS when unset.

        Raises:
            ValueError: if ``load_directory`` does not exist.
        """
        if not os.path.exists(load_directory):
            raise ValueError(f"Directory {load_directory} does not exist.")

        # NOTE(review): fp16 loading is keyed on the directory NAME containing
        # 'gpt-j'/'GPT-J' — checkpoints saved under other names load in full
        # precision; confirm this matches how checkpoints are organized.
        if 'gpt-j' in load_directory or 'GPT-J' in load_directory:
            model_kwargs = dict(
                torch_dtype=torch.float16,
                revision='float16'
            )
        else:
            model_kwargs = {}

        self.scoring_model = AutoPeftModelForCausalLM.from_pretrained(load_directory, **model_kwargs)
        self.scoring_tokenizer = AutoTokenizer.from_pretrained(load_directory)
        self.scoring_model_name = self.scoring_model.config._name_or_path

        if load_directory_ref:
            self.reference_model = AutoModelForCausalLM.from_pretrained(load_directory_ref, **model_kwargs)
            self.reference_tokenizer = AutoTokenizer.from_pretrained(load_directory_ref)
            self.reference_model_name = self.reference_model.config._name_or_path
        else:
            self.reference_model = None
            self.reference_tokenizer = None
            self.reference_model_name = None

        if self.scoring_tokenizer.pad_token is None:
            self.scoring_tokenizer.pad_token = self.scoring_tokenizer.eos_token
            self.scoring_tokenizer.pad_token_id = self.scoring_tokenizer.eos_token_id
        if self.reference_tokenizer is not None:
            if self.reference_tokenizer.pad_token is None:
                self.reference_tokenizer.pad_token = self.reference_tokenizer.eos_token
                self.reference_tokenizer.pad_token_id = self.reference_tokenizer.eos_token_id


    def get_sampling_discrepancy_analytic(self, reference_logits, scoring_logits, labels, attention_mask):
        """Compute the analytic sampling discrepancy per sequence.

        For each sequence, computes
        ``(sum log p_score(label) - sum E_ref[log p_score]) / sqrt(sum Var_ref[log p_score])``
        where expectation/variance are taken under the reference
        distribution, summing only over non-padding positions.

        Args:
            reference_logits: [bsz, seq_len-1, vocab] reference-model logits.
            scoring_logits: [bsz, seq_len-1, vocab] scoring-model logits.
            labels: next-token ids aligned with the logits.
            attention_mask: [bsz, seq_len] mask; position 0 is dropped to
                align with the shifted labels.

        Returns:
            Tuple ``(discrepancy, log_likelihood_sum)``, each of shape [bsz].
        """
        # Different tokenizer families can disagree on vocab size (e.g. added
        # special tokens); truncate both to the shared prefix of the vocab.
        if reference_logits.size(-1) != scoring_logits.size(-1):
            vocab_size = min(reference_logits.size(-1), scoring_logits.size(-1))
            reference_logits = reference_logits[:, :, :vocab_size]
            scoring_logits = scoring_logits[:, :, :vocab_size]

        labels = labels.unsqueeze(-1) if labels.ndim == scoring_logits.ndim - 1 else labels
        lprobs_score = torch.log_softmax(scoring_logits, dim=-1)
        probs_ref = torch.softmax(reference_logits, dim=-1)

        log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
        mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
        var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)

        mask = attention_mask[:, 1:].float()  # [bsz, seq_len-1], 1 for non-pad, 0 for pad
        log_likelihood_sum = (log_likelihood * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        mean_ref_sum = (mean_ref * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        var_ref_sum = (var_ref * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        discrepancy = (log_likelihood_sum - mean_ref_sum) / (var_ref_sum.sqrt() + 1e-8)  # [bsz], avoid division by zero

        return discrepancy, log_likelihood_sum

    def get_discrepancy_of_scoring_and_reference_models(self,
                                                        input_ids_for_scoring_model,
                                                        attention_mask_for_scoring_model,
                                                        input_ids_for_reference_model=None,
                                                        attention_mask_for_reference_model=None,
                                                        ) -> dict:
        """Run both models on one batch and return discrepancy statistics.

        Reference-model logits are computed under ``torch.no_grad`` (the
        reference is never trained here). When no reference model is set,
        the scoring logits themselves serve as the reference distribution
        and the 'reference_*' outputs are None.

        Returns:
            dict with keys 'scoring_discrepancy', 'scoring_logprob',
            'reference_discrepancy', 'reference_logprob' (per-sequence
            tensors, or None for the reference entries without a
            reference model).
        """
        labels = input_ids_for_scoring_model[:, 1:]  # shape: [bsz, sentence_len - 1]
        scoring_logits = self.scoring_model(input_ids_for_scoring_model,
                                            attention_mask=attention_mask_for_scoring_model).logits[:,:-1,:]
        if self.reference_model is not None:
            assert input_ids_for_reference_model is not None and attention_mask_for_reference_model is not None, \
                "If reference_model is provided, you should provide reference_tokenizer to dataset initialization."
            with torch.no_grad():
                # check if tokenizer is the match
                reference_labels = input_ids_for_reference_model[:, 1:]  # shape: [bsz, sentence_len]
                assert torch.all(reference_labels == labels), \
                    "Tokenizer is mismatch."
                reference_logits = self.reference_model(input_ids_for_reference_model,
                                                        attention_mask=attention_mask_for_reference_model).logits[:,:-1,:]
        else:
            reference_logits = scoring_logits

        if self.reference_model is not None:
            # Reference-vs-itself discrepancy: the reference distribution is
            # both the scorer and the expectation baseline here.
            discrepancy_ref, logprob_ref = self.get_sampling_discrepancy_analytic(reference_logits, reference_logits,
                                                                                  labels, attention_mask=attention_mask_for_reference_model)
        else:
            discrepancy_ref, logprob_ref = None, None
        discrepancy_score, logprob_score = self.get_sampling_discrepancy_analytic(reference_logits, scoring_logits,
                                                                                  labels, attention_mask=attention_mask_for_scoring_model)

        return {
            'scoring_discrepancy': discrepancy_score,
            'scoring_logprob': logprob_score,
            'reference_discrepancy': discrepancy_ref,
            'reference_logprob': logprob_ref,
        }

    def forward(self,
                scoring_original_input_ids,
                scoring_original_attention_mask,
                scoring_rewritten_input_ids,
                scoring_rewritten_attention_mask,
                reference_original_input_ids=None,
                reference_original_attention_mask=None,
                reference_rewritten_input_ids=None,
                reference_rewritten_attention_mask=None,
                ) -> dict:
        """Compute discrepancies for an (original, rewritten) text pair.

        Validates that reference inputs are supplied exactly when the
        training method requires them ('SPO': required, 'DDL': forbidden),
        then scores both texts with
        ``get_discrepancy_of_scoring_and_reference_models``.

        Returns:
            dict mapping
            '{scoring,reference}_{original,rewritten}_{discrepancy,logprob}'
            to per-sequence tensors (reference entries are None for DDL).
        """
        if self.train_method == 'SPO':
            assert reference_original_input_ids is not None and reference_original_attention_mask is not None, \
                "If train_method is SPO, you should provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is not None and reference_rewritten_attention_mask is not None, \
                "If train_method is SPO, you should provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        elif self.train_method == 'DDL':
            assert reference_original_input_ids is None and reference_original_attention_mask is None, \
                "If train_method is DDL, you should not provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is None and reference_rewritten_attention_mask is None, \
                "If train_method is DDL, you should not provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        else:
            raise ValueError('train_method should be DDL or SPO.')
        original_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_original_input_ids,
            attention_mask_for_scoring_model=scoring_original_attention_mask,
            input_ids_for_reference_model=reference_original_input_ids,
            attention_mask_for_reference_model=reference_original_attention_mask,
        )
        rewritten_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_rewritten_input_ids,
            attention_mask_for_scoring_model=scoring_rewritten_attention_mask,
            input_ids_for_reference_model=reference_rewritten_input_ids,
            attention_mask_for_reference_model=reference_rewritten_attention_mask,
        )

        return {
            'scoring_original_discrepancy': original_output['scoring_discrepancy'],
            'scoring_original_logprob': original_output['scoring_logprob'],
            'scoring_rewritten_discrepancy': rewritten_output['scoring_discrepancy'],
            'scoring_rewritten_logprob': rewritten_output['scoring_logprob'],
            'reference_original_discrepancy': original_output['reference_discrepancy'],
            'reference_original_logprob': original_output['reference_logprob'],
            'reference_rewritten_discrepancy': rewritten_output['reference_discrepancy'],
            'reference_rewritten_logprob': rewritten_output['reference_logprob'],
        }
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ peft
2
+ torch
3
+ transformers
4
+ protobuf
5
+ python-docx
6
+ gradio
7
+ numpy
8
+ huggingface_hub
9
+ datasets
10
+ spaces