Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
6600152
0
Parent(s):
update: app
Browse files- .gitignore +165 -0
- .gradio/certificate.pem +31 -0
- .gradio/flagged/dataset1.csv +64 -0
- DetectAnyLLM +1 -0
- LICENSE +35 -0
- README.md +14 -0
- app.py +539 -0
- core/model.py +255 -0
- requirements.txt +10 -0
.gitignore
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
ckpt/*/
|
| 12 |
+
logs/*/
|
| 13 |
+
models/*/
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
share/python-wheels/
|
| 27 |
+
*.egg-info/
|
| 28 |
+
.installed.cfg
|
| 29 |
+
*.egg
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.nox/
|
| 46 |
+
.coverage
|
| 47 |
+
.coverage.*
|
| 48 |
+
.cache
|
| 49 |
+
nosetests.xml
|
| 50 |
+
coverage.xml
|
| 51 |
+
*.cover
|
| 52 |
+
*.py,cover
|
| 53 |
+
.hypothesis/
|
| 54 |
+
.pytest_cache/
|
| 55 |
+
cover/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
.pybuilder/
|
| 79 |
+
target/
|
| 80 |
+
|
| 81 |
+
# Jupyter Notebook
|
| 82 |
+
.ipynb_checkpoints
|
| 83 |
+
|
| 84 |
+
# IPython
|
| 85 |
+
profile_default/
|
| 86 |
+
ipython_config.py
|
| 87 |
+
|
| 88 |
+
# pyenv
|
| 89 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 91 |
+
# .python-version
|
| 92 |
+
|
| 93 |
+
# pipenv
|
| 94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 97 |
+
# install all needed dependencies.
|
| 98 |
+
#Pipfile.lock
|
| 99 |
+
|
| 100 |
+
# poetry
|
| 101 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 103 |
+
# commonly ignored for libraries.
|
| 104 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 105 |
+
#poetry.lock
|
| 106 |
+
|
| 107 |
+
# pdm
|
| 108 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 109 |
+
#pdm.lock
|
| 110 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 111 |
+
# in version control.
|
| 112 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 113 |
+
.pdm.toml
|
| 114 |
+
.pdm-python
|
| 115 |
+
.pdm-build/
|
| 116 |
+
|
| 117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 118 |
+
__pypackages__/
|
| 119 |
+
|
| 120 |
+
# Celery stuff
|
| 121 |
+
celerybeat-schedule
|
| 122 |
+
celerybeat.pid
|
| 123 |
+
|
| 124 |
+
# SageMath parsed files
|
| 125 |
+
*.sage.py
|
| 126 |
+
|
| 127 |
+
# Environments
|
| 128 |
+
.env
|
| 129 |
+
.venv
|
| 130 |
+
env/
|
| 131 |
+
venv/
|
| 132 |
+
ENV/
|
| 133 |
+
env.bak/
|
| 134 |
+
venv.bak/
|
| 135 |
+
|
| 136 |
+
# Spyder project settings
|
| 137 |
+
.spyderproject
|
| 138 |
+
.spyproject
|
| 139 |
+
|
| 140 |
+
# Rope project settings
|
| 141 |
+
.ropeproject
|
| 142 |
+
|
| 143 |
+
# mkdocs documentation
|
| 144 |
+
/site
|
| 145 |
+
|
| 146 |
+
# mypy
|
| 147 |
+
.mypy_cache/
|
| 148 |
+
.dmypy.json
|
| 149 |
+
dmypy.json
|
| 150 |
+
|
| 151 |
+
# Pyre type checker
|
| 152 |
+
.pyre/
|
| 153 |
+
|
| 154 |
+
# pytype static type analyzer
|
| 155 |
+
.pytype/
|
| 156 |
+
|
| 157 |
+
# Cython debug symbols
|
| 158 |
+
cython_debug/
|
| 159 |
+
|
| 160 |
+
# PyCharm
|
| 161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 165 |
+
#.idea/
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
.gradio/flagged/dataset1.csv
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
input_text,output,timestamp
|
| 2 |
+
"'def greet(input_text):
|
| 3 |
+
sub_texts = text_splitter.split_text(input_text) # 修改为split_text获取文本列表
|
| 4 |
+
html_output = []
|
| 5 |
+
for sub_text in sub_texts:
|
| 6 |
+
tokenized = scoring_tokenizer(sub_text, truncation=True, return_tensors=""pt"", padding=True, return_token_type_ids=False).to(device)
|
| 7 |
+
labels = tokenized.input_ids[:, 1:]
|
| 8 |
+
with torch.no_grad():
|
| 9 |
+
logits_score = scoring_model(**tokenized).logits[:, :-1]
|
| 10 |
+
logits_ref = logits_score
|
| 11 |
+
crit, _ = criterion_fn(logits_ref, logits_score, labels)
|
| 12 |
+
|
| 13 |
+
crit = crit.cpu().numpy().item()
|
| 14 |
+
prob = prob_estimator.crit_to_prob(crit)
|
| 15 |
+
|
| 16 |
+
# 根据概率值设置颜色
|
| 17 |
+
if prob >= 0.7:
|
| 18 |
+
color = ""red""
|
| 19 |
+
elif prob >= 0.3:
|
| 20 |
+
color = ""orange""
|
| 21 |
+
else:
|
| 22 |
+
color = ""white""
|
| 23 |
+
|
| 24 |
+
# 创建带样式的HTML内容
|
| 25 |
+
html_output.append(f'<span style=""color: {color};"">{sub_text} (Probability: {prob:.2f})</span>')
|
| 26 |
+
|
| 27 |
+
# 用换行连接所有结果
|
| 28 |
+
return ""<br>"".join(html_output)
|
| 29 |
+
|
| 30 |
+
demo = gr.Interface(
|
| 31 |
+
fn=greet,
|
| 32 |
+
inputs=[""text""],
|
| 33 |
+
outputs=gr.HTML() # 修改为HTML输出组件
|
| 34 |
+
)","'<span style=""color: white;"">def greet(input_text):
|
| 35 |
+
sub_texts = text_splitter.split_text(input_text) # 修改为split_text获取文本列表
|
| 36 |
+
html_output = []
|
| 37 |
+
for (Probability: 0.09)</span><br><span style=""color: white;"">sub_text in sub_texts:
|
| 38 |
+
tokenized = scoring_tokenizer(sub_text, truncation=True, return_tensors=""pt"", padding=True, retur (Probability: 0.03)</span><br><span style=""color: white;"">n_token_type_ids=False).to(device)
|
| 39 |
+
labels = tokenized.input_ids[:, 1:]
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
logits_ (Probability: 0.05)</span><br><span style=""color: white;"">score = scoring_model(**tokenized).logits[:, :-1]
|
| 42 |
+
logits_ref = logits_score
|
| 43 |
+
crit, _ = criterion_fn(logit (Probability: 0.00)</span><br><span style=""color: white;"">s_ref, logits_score, labels)
|
| 44 |
+
|
| 45 |
+
crit = crit.cpu().numpy().item()
|
| 46 |
+
prob = prob_estimator.crit_to_prob(crit) (Probability: 0.02)</span><br><span style=""color: white;""># 根据概率值设置颜色
|
| 47 |
+
if prob >= 0.7:
|
| 48 |
+
color = ""red""
|
| 49 |
+
elif prob >= 0.3:
|
| 50 |
+
color = ""or (Probability: 0.09)</span><br><span style=""color: white;"">ange""
|
| 51 |
+
else:
|
| 52 |
+
color = ""white""
|
| 53 |
+
|
| 54 |
+
# 创建带样式的HTML内容
|
| 55 |
+
html_output.append(f'<span style=""color: (Probability: 0.19)</span><br><span style=""color: white;"">{color};"">{sub_text} (Probability: {prob:.2f})</span>')
|
| 56 |
+
|
| 57 |
+
# 用换行连接所有结果
|
| 58 |
+
return ""<br>"".join(html_output)
|
| 59 |
+
|
| 60 |
+
demo = gr.Int (Probability: 0.01)</span><br><span style=""color: white;"">erface(
|
| 61 |
+
fn=greet,
|
| 62 |
+
inputs=[""text""],
|
| 63 |
+
outputs=gr.HTML() # 修改为HTML输出组件
|
| 64 |
+
) (Probability: 0.06)</span>",2025-01-30 11:44:36.020197
|
DetectAnyLLM
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 2d182abad5143fc1183cfedca2a30f58c3d44e7e
|
LICENSE
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Pi-Lab License 1.0
|
| 2 |
+
|
| 3 |
+
Copyright 2025 Pi-Lab
|
| 4 |
+
|
| 5 |
+
Redistribution and use for non-commercial purpose in source and
|
| 6 |
+
binary forms, with or without modification, are permitted provided
|
| 7 |
+
that the following conditions are met:
|
| 8 |
+
|
| 9 |
+
1. Redistributions of source code must retain the above copyright
|
| 10 |
+
notice, this list of conditions and the following disclaimer.
|
| 11 |
+
|
| 12 |
+
2. Redistributions in binary form must reproduce the above copyright
|
| 13 |
+
notice, this list of conditions and the following disclaimer in
|
| 14 |
+
the documentation and/or other materials provided with the
|
| 15 |
+
distribution.
|
| 16 |
+
|
| 17 |
+
3. Neither the name of the copyright holder nor the names of its
|
| 18 |
+
contributors may be used to endorse or promote products derived
|
| 19 |
+
from this software without specific prior written permission.
|
| 20 |
+
|
| 21 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 22 |
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 23 |
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 24 |
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 25 |
+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 26 |
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 27 |
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 28 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 29 |
+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 30 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 31 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 32 |
+
|
| 33 |
+
In the event that redistribution and/or use for commercial purpose in
|
| 34 |
+
source or binary forms, with or without modification is required,
|
| 35 |
+
please contact the contributor(s) of the work.
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: DetectAnyLLM
|
| 3 |
+
emoji: 🔥
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.46.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: other
|
| 11 |
+
short_description: '[ACMMM 2025] State-Of-The-Art AI-Text Detector'
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
from core.model import DiscrepancyEstimator
|
| 7 |
+
import re
|
| 8 |
+
import docx
|
| 9 |
+
import spaces
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def read_file_content(file):
|
| 14 |
+
if file is None:
|
| 15 |
+
return ""
|
| 16 |
+
if file.name.endswith('.txt'):
|
| 17 |
+
with open(file.name, 'r', encoding='utf-8') as f:
|
| 18 |
+
return f.read()
|
| 19 |
+
elif file.name.endswith('.docx'):
|
| 20 |
+
doc = docx.Document(file.name)
|
| 21 |
+
full_text = []
|
| 22 |
+
for para in doc.paragraphs:
|
| 23 |
+
full_text.append(para.text)
|
| 24 |
+
return '\n'.join(full_text)
|
| 25 |
+
return ""
|
| 26 |
+
|
| 27 |
+
def split_sentences(text):
|
| 28 |
+
"""根据句号、句点、分号分割文本成句子,同时保留分句符号。"""
|
| 29 |
+
sentences = re.split(r'([。.])', text)
|
| 30 |
+
combined_sentences = [sentences[i] + sentences[i+1] for i in range(0, len(sentences)-1, 2)]
|
| 31 |
+
if len(sentences) % 2 == 1:
|
| 32 |
+
combined_sentences.append(sentences[-1])
|
| 33 |
+
return [s.strip() for s in combined_sentences if s.strip()]
|
| 34 |
+
|
| 35 |
+
def count_words(sentence, language='Chinese'):
|
| 36 |
+
"""统计句子的词数。"""
|
| 37 |
+
return len(sentence.replace('\n', '').replace('\r', '').split()) if language != 'Chinese' else len(sentence.replace('\n', '').replace('\r', ''))
|
| 38 |
+
|
| 39 |
+
def segment_text(sentences, language='Chinese'):
|
| 40 |
+
"""按照要求拼接句子,确保不忽略第一段并处理最后一句话不足100词的情况。"""
|
| 41 |
+
result = []
|
| 42 |
+
current_segment = []
|
| 43 |
+
current_length = 0
|
| 44 |
+
|
| 45 |
+
for i, sentence in enumerate(sentences):
|
| 46 |
+
word_count = count_words(sentence, language)
|
| 47 |
+
|
| 48 |
+
if word_count > 100:
|
| 49 |
+
# 如果单个句子超过100词,考虑拼接
|
| 50 |
+
if i + 1 < len(sentences) and word_count + count_words(sentences[i + 1], language) <= 200:
|
| 51 |
+
# 拼接当前和下一个句子
|
| 52 |
+
if current_segment: # 先保存当前段
|
| 53 |
+
result.append(' '.join(current_segment) if language != 'Chinese' else ''.join(current_segment))
|
| 54 |
+
result.append((sentence + ' ' + sentences[i + 1]) if language != 'Chinese' else (sentence + sentences[i + 1]))
|
| 55 |
+
current_segment = []
|
| 56 |
+
current_length = 0
|
| 57 |
+
i += 1 # 跳过下一个句子
|
| 58 |
+
continue
|
| 59 |
+
else:
|
| 60 |
+
# 单独存放
|
| 61 |
+
if current_segment: # 先保存当前段
|
| 62 |
+
result.append(' '.join(current_segment) if language != 'Chinese' else ''.join(current_segment))
|
| 63 |
+
result.append(sentence)
|
| 64 |
+
current_segment = []
|
| 65 |
+
current_length = 0
|
| 66 |
+
else:
|
| 67 |
+
if current_length + word_count > 100:
|
| 68 |
+
# 当前段超过100词,保存并开始新段
|
| 69 |
+
if current_segment:
|
| 70 |
+
result.append(' '.join(current_segment) if language != 'Chinese' else ''.join(current_segment))
|
| 71 |
+
current_segment = [sentence]
|
| 72 |
+
current_length = word_count
|
| 73 |
+
else:
|
| 74 |
+
# 继续累积
|
| 75 |
+
current_segment.append(sentence)
|
| 76 |
+
current_length += word_count
|
| 77 |
+
|
| 78 |
+
# 处理最后一段
|
| 79 |
+
if current_segment:
|
| 80 |
+
if current_length < 100 and result and current_length + count_words(result[-1], language) <= 200:
|
| 81 |
+
# 如果最后一段不足100词,且可以与前一段合并
|
| 82 |
+
last_segment = result.pop() if result else ''
|
| 83 |
+
current_segment = (last_segment.split() if language != 'Chinese' else list(last_segment)) + current_segment
|
| 84 |
+
result.append(' '.join(current_segment) if language != 'Chinese' else ''.join(current_segment))
|
| 85 |
+
else:
|
| 86 |
+
# 直接添加最后一段
|
| 87 |
+
result.append(' '.join(current_segment) if language != 'Chinese' else ''.join(current_segment))
|
| 88 |
+
|
| 89 |
+
return result
|
| 90 |
+
|
| 91 |
+
def extract_latex_text(latex_source):
|
| 92 |
+
# 提取document环境中的内容
|
| 93 |
+
doc_pattern = re.compile(r'\\begin{document}(.*?)\\end{document}', re.DOTALL)
|
| 94 |
+
match = doc_pattern.search(latex_source)
|
| 95 |
+
content = match.group(1) if match else latex_source
|
| 96 |
+
|
| 97 |
+
# 删除注释(排除转义后的%)
|
| 98 |
+
content = re.sub(r'(?<!\\)%.*', '', content, flags=re.MULTILINE)
|
| 99 |
+
|
| 100 |
+
# 排除常见非文本环境
|
| 101 |
+
excluded_envs = ['figure', 'table', 'equation', 'align\*?', 'verbatim', 'lstlisting']
|
| 102 |
+
env_pattern = re.compile(
|
| 103 |
+
r'\\begin{(' + '|'.join(excluded_envs) + r')}.*?\\end{\1}',
|
| 104 |
+
re.DOTALL
|
| 105 |
+
)
|
| 106 |
+
content = env_pattern.sub('', content)
|
| 107 |
+
|
| 108 |
+
# 新增处理:删除所有cite命令及其内容
|
| 109 |
+
content = re.sub(r'\\cite(\[[^\]]*\])?\{[^}]*\}', '', content)
|
| 110 |
+
|
| 111 |
+
# 新增处理:删除行内table/figure命令及其内容
|
| 112 |
+
content = re.sub(r'\\(table|figure)\*?(\[[^\]]*\])?\{[^}]*\}', '', content)
|
| 113 |
+
|
| 114 |
+
# 删除简单命令(无参数)
|
| 115 |
+
content = re.sub(r'\\([a-zA-Z]+)\*?\b', '', content)
|
| 116 |
+
|
| 117 |
+
# 递归处理带参数的命令(最多迭代10次防止死循环)
|
| 118 |
+
for _ in range(10):
|
| 119 |
+
new_content = re.sub(
|
| 120 |
+
r'\\([a-zA-Z]+)\*?(?:\[.*?\])*{((?:[^{}]*|{[^{}]*})*)}',
|
| 121 |
+
lambda m: m.group(2),
|
| 122 |
+
content,
|
| 123 |
+
flags=re.DOTALL
|
| 124 |
+
)
|
| 125 |
+
if new_content == content:
|
| 126 |
+
break
|
| 127 |
+
content = new_content
|
| 128 |
+
|
| 129 |
+
# 处理特殊字符
|
| 130 |
+
replacements = {
|
| 131 |
+
'~': ' ', '\\&': '&', '\\$': '$', '\\%': '%',
|
| 132 |
+
'\\_': '_', '\\#': '#', '\\\\': '\n', '\n': ' ',
|
| 133 |
+
'“': '"', '”': '"', '‘': "'", '’': "'"
|
| 134 |
+
}
|
| 135 |
+
for k, v in replacements.items():
|
| 136 |
+
content = content.replace(k, v)
|
| 137 |
+
|
| 138 |
+
# 清理空白字符
|
| 139 |
+
content = re.sub(r'[ \t]+', ' ', content)
|
| 140 |
+
content = re.sub(r'\n{2,}', '\n\n', content)
|
| 141 |
+
return content.strip()
|
| 142 |
+
|
| 143 |
+
class ProbEstimator:
|
| 144 |
+
def __init__(self, ref_file_dir):
|
| 145 |
+
self.tasks = ["polish", "generate", "rewrite"]
|
| 146 |
+
self.real_crits = {"polish": [], "generate": [], "rewrite": []}
|
| 147 |
+
self.fake_crits = {"polish": [], "generate": [], "rewrite": []}
|
| 148 |
+
for task in self.tasks:
|
| 149 |
+
task_ref_data = load_dataset(ref_file_dir, data_files=f'{task}.json')['train']
|
| 150 |
+
self.real_crits[task].extend(task_ref_data['original_discrepancy'])
|
| 151 |
+
self.fake_crits[task].extend(task_ref_data['rewritten_discrepancy'])
|
| 152 |
+
print(f'ProbEstimator: total {sum([len(self.real_crits[task]) for task in self.tasks]) * 2} samples.')
|
| 153 |
+
|
| 154 |
+
def crit_to_prob(self, crit):
|
| 155 |
+
probs = {}
|
| 156 |
+
for task in self.tasks:
|
| 157 |
+
real_crits = self.real_crits[task]
|
| 158 |
+
fake_crits = self.fake_crits[task]
|
| 159 |
+
total_len = len(real_crits) + len(fake_crits)
|
| 160 |
+
offset = np.sort(np.abs(np.array(real_crits + fake_crits) - crit))[int(0.1*total_len)]
|
| 161 |
+
cnt_real = np.sum((np.array(real_crits) > crit - offset) & (np.array(real_crits) < crit + offset))
|
| 162 |
+
cnt_fake = np.sum((np.array(fake_crits) > crit - offset) & (np.array(fake_crits) < crit + offset))
|
| 163 |
+
probs[task] = (cnt_fake / (cnt_real + cnt_fake)) if (cnt_real + cnt_fake) > 0 else 0.5
|
| 164 |
+
return probs
|
| 165 |
+
|
| 166 |
+
device = 'cuda'
|
| 167 |
+
zh_prob_estimator = ProbEstimator(ref_file_dir="JiachenFu/Qwen2-0.5B-detectanyllm-detector-ref-zh")
|
| 168 |
+
en_prob_estimator = ProbEstimator(ref_file_dir="JiachenFu/Qwen2-0.5B-detectanyllm-detector-ref-en")
|
| 169 |
+
|
| 170 |
+
@spaces.GPU
|
| 171 |
+
def greet(mode, language, input_text):
|
| 172 |
+
if mode == "LaTex":
|
| 173 |
+
input_text = extract_latex_text(input_text)
|
| 174 |
+
split_texts = split_sentences(input_text)
|
| 175 |
+
sub_texts = segment_text(split_texts, language=language)
|
| 176 |
+
detected = []
|
| 177 |
+
if language == "Chinese":
|
| 178 |
+
model = DiscrepancyEstimator(load_directory="JiachenFu/Qwen2-0.5B-detectanyllm-detector-zh").to(device)
|
| 179 |
+
prob_estimator = zh_prob_estimator
|
| 180 |
+
else:
|
| 181 |
+
model = DiscrepancyEstimator(load_directory="JiachenFu/Qwen2-0.5B-detectanyllm-detector-en").to(device)
|
| 182 |
+
prob_estimator = en_prob_estimator
|
| 183 |
+
model.eval()
|
| 184 |
+
for i, sub_text in enumerate(sub_texts):
|
| 185 |
+
text_content = sub_text
|
| 186 |
+
print(f'processing {sub_text}')
|
| 187 |
+
tokens = model.scoring_tokenizer(
|
| 188 |
+
text_content, return_tensors='pt', padding=True, truncation=True, return_token_type_ids=False
|
| 189 |
+
)
|
| 190 |
+
print(f'tokenized')
|
| 191 |
+
input_ids = tokens['input_ids'].to(device)
|
| 192 |
+
attention_mask = tokens['attention_mask'].to(device)
|
| 193 |
+
with torch.no_grad():
|
| 194 |
+
output = model.get_discrepancy_of_scoring_and_reference_models(
|
| 195 |
+
input_ids_for_scoring_model=input_ids,
|
| 196 |
+
attention_mask_for_scoring_model=attention_mask,
|
| 197 |
+
input_ids_for_reference_model=None,
|
| 198 |
+
attention_mask_for_reference_model=None,
|
| 199 |
+
)
|
| 200 |
+
discrepancy = output['scoring_discrepancy']
|
| 201 |
+
discrepancy = discrepancy.cpu().numpy().item()
|
| 202 |
+
print(f'discrepancy: {discrepancy}')
|
| 203 |
+
probs = prob_estimator.crit_to_prob(discrepancy)
|
| 204 |
+
if discrepancy < 15:
|
| 205 |
+
for task in probs.keys():
|
| 206 |
+
probs[task] = 0.0
|
| 207 |
+
detected.append({
|
| 208 |
+
'order': i,
|
| 209 |
+
'text': text_content,
|
| 210 |
+
'words_count': len(text_content) if language == "Chinese" else len(text_content.split()),
|
| 211 |
+
'probs': probs
|
| 212 |
+
})
|
| 213 |
+
|
| 214 |
+
# 添加绝对定位的总概率显示
|
| 215 |
+
# 构建动画效果
|
| 216 |
+
html_output = '''
|
| 217 |
+
<style>
|
| 218 |
+
@keyframes reveal {
|
| 219 |
+
from { opacity: 0; }
|
| 220 |
+
to { opacity: 1; }
|
| 221 |
+
}
|
| 222 |
+
.reveal-char {
|
| 223 |
+
opacity: 0;
|
| 224 |
+
animation: reveal 0.2s forwards;
|
| 225 |
+
white-space: pre-wrap;
|
| 226 |
+
}
|
| 227 |
+
</style>
|
| 228 |
+
<div style="position: relative; padding-bottom: 60px; min-height: 120px;">
|
| 229 |
+
'''
|
| 230 |
+
|
| 231 |
+
current_delay = 0.0 # 当前动画延迟时间
|
| 232 |
+
char_duration = 0.001 # 每个字符的间隔时间
|
| 233 |
+
|
| 234 |
+
# 处理文本内容
|
| 235 |
+
for item in detected:
|
| 236 |
+
ai_generate_prob = item['probs']['generate']
|
| 237 |
+
ai_revise_prob = max(item['probs']['polish'], item['probs']['rewrite'])
|
| 238 |
+
prob = max(ai_generate_prob, ai_revise_prob)
|
| 239 |
+
if prob >= 0.75:
|
| 240 |
+
if ai_generate_prob >= ai_revise_prob:
|
| 241 |
+
color = "red"
|
| 242 |
+
item["generate"] = 1
|
| 243 |
+
item["revise"] = 0
|
| 244 |
+
else:
|
| 245 |
+
color = "orange"
|
| 246 |
+
item["generate"] = 0
|
| 247 |
+
item["revise"] = 1
|
| 248 |
+
else:
|
| 249 |
+
color = "black"
|
| 250 |
+
item["generate"] = 0
|
| 251 |
+
item["revise"] = 0
|
| 252 |
+
|
| 253 |
+
for char in item['text']:
|
| 254 |
+
html_output += f'<span class="reveal-char" style="color: {color}; animation-delay: {current_delay:.2f}s;">{char}</span>'
|
| 255 |
+
current_delay += char_duration
|
| 256 |
+
|
| 257 |
+
total_length = sum(item['words_count'] for item in detected)
|
| 258 |
+
# total_prob = sum(item['prob'] * item['words_count'] for item in detected) / total_length if total_length > 0 else 0
|
| 259 |
+
generate_prob = sum(item["generate"] * item["words_count"] for item in detected) / total_length if total_length > 0 else 0
|
| 260 |
+
revise_prob = sum(item["revise"] * item["words_count"] for item in detected) / total_length if total_length > 0 else 0
|
| 261 |
+
html_output += f'''
|
| 262 |
+
<div style="
|
| 263 |
+
position: absolute;
|
| 264 |
+
bottom: 0;
|
| 265 |
+
right: 0;
|
| 266 |
+
background-color: rgba(255, 255, 255, 0.9);
|
| 267 |
+
padding: 8px 12px;
|
| 268 |
+
border-radius: 4px;
|
| 269 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
| 270 |
+
border: 1px solid #e0e0e0;
|
| 271 |
+
font-size: 14px;
|
| 272 |
+
">
|
| 273 |
+
🤖 AI Generated Rate: <strong>{generate_prob:.2%}</strong><br>
|
| 274 |
+
✍️ AI Revised Rate: <strong>{revise_prob:.2%}</strong>
|
| 275 |
+
</div>
|
| 276 |
+
'''
|
| 277 |
+
|
| 278 |
+
html_output += '</div>'
|
| 279 |
+
return html_output
|
| 280 |
+
|
| 281 |
+
# Use Blocks instead of Interface for finer-grained layout/CSS control.
with gr.Blocks(css="""
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');

:root {
    --accent-color: #6366f1;
    --text-color: #374151;
    --border-color: #e5e7eb;
    --background-light: #f9fafb;
    --background-card: #ffffff;
}

body, .gradio-container {
    background: var(--background-light);
    font-family: 'Inter', sans-serif;
    color: var(--text-color);
}

#header {
    text-align: center;
    padding: 2rem;
    margin: 0 auto; /* Use gap for spacing, remove margin-bottom */
    background-color: var(--background-card);
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40' viewBox='0 0 40 40'%3E%3Cg fill-rule='evenodd'%3E%3Cg fill='%23e5e7eb' fill-opacity='0.3'%3E%3Cpath d='M0 38.59l2.83-2.83 1.41 1.41L1.41 40H0v-1.41zM0 1.4l2.83 2.83 1.41-1.41L1.41 0H0v1.41zM38.59 40l-2.83-2.83 1.41-1.41L40 38.59V40h-1.41zM40 1.41l-2.83 2.83-1.41-1.41L38.59 0H40v1.41zM20 18.6l2.83-2.83 1.41 1.41L21.41 20l2.83 2.83-1.41 1.41L20 21.41l-2.83 2.83-1.41-1.41L18.59 20l-2.83-2.83 1.41-1.41L20 18.59z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E");
    border: 1px solid var(--border-color);
    border-radius: 16px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.05);
}
#title {
    font-weight: 800;
    font-size: 2.5em;
    letter-spacing: -0.02em;
    color: var(--text-color);
    margin-bottom: 0.25em;
}
.detect-grad {
    background: -webkit-linear-gradient(left, #ff8c8c, #ffc89e);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-weight: 800;
}
.anyllm-grad {
    background: -webkit-linear-gradient(left, #a0e6ff, #aaffd4);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-weight: 800;
}
#authors {
    font-size: 1.1em;
    color: #6b7280;
    margin: 0;
}

#main-container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 0 1rem;
    gap: 2rem; /* Add gap for consistent spacing */
}

#controls-row {
    justify-content: center;
    gap: 2rem;
}

/* Custom styles for Radio Button Groups */
#controls-row > div {
    background-color: var(--background-card);
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40' viewBox='0 0 40 40'%3E%3Cg fill-rule='evenodd'%3E%3Cg fill='%23e5e7eb' fill-opacity='0.3'%3E%3Cpath d='M0 38.59l2.83-2.83 1.41 1.41L1.41 40H0v-1.41zM0 1.4l2.83 2.83 1.41-1.41L1.41 0H0v1.41zM38.59 40l-2.83-2.83 1.41-1.41L40 38.59V40h-1.41zM40 1.41l-2.83 2.83-1.41-1.41L38.59 0H40v1.41zM20 18.6l2.83-2.83 1.41 1.41L21.41 20l2.83 2.83-1.41 1.41L20 21.41l-2.83 2.83-1.41-1.41L18.59 20l-2.83-2.83 1.41-1.41L20 18.59z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E");
    border: 1px solid var(--border-color);
    border-radius: 16px;
    padding: 1rem;
    box-shadow: 0 4px 12px rgba(0,0,0,0.05);
}

#controls-row .gradio-button {
    border-radius: 10px !important;
    transition: background-color 0.2s ease, color 0.2s ease;
}

#controls-row .gradio-button.selected {
    background: var(--accent-color) !important;
    color: white !important;
    border-color: var(--accent-color) !important;
}

#content-row {
    gap: 1.5rem;
}

.card {
    background-color: var(--background-card);
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40' viewBox='0 0 40 40'%3E%3Cg fill-rule='evenodd'%3E%3Cg fill='%23e5e7eb' fill-opacity='0.3'%3E%3Cpath d='M0 38.59l2.83-2.83 1.41 1.41L1.41 40H0v-1.41zM0 1.4l2.83 2.83 1.41-1.41L1.41 0H0v1.41zM38.59 40l-2.83-2.83 1.41-1.41L40 38.59V40h-1.41zM40 1.41l-2.83 2.83-1.41-1.41L38.59 0H40v1.41zM20 18.6l2.83-2.83 1.41 1.41L21.41 20l2.83 2.83-1.41 1.41L20 21.41l-2.83 2.83-1.41-1.41L18.59 20l-2.83-2.83 1.41-1.41L20 18.59z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E");
    border: 1px solid var(--border-color);
    border-radius: 16px;
    padding: 1.5rem;
    box-shadow: 0 4px 12px rgba(0,0,0,0.05);
    height: 100%;
    display: flex;
    flex-direction: column;
    gap: 1rem;
}

.card-title {
    font-weight: 600;
    font-size: 1.2rem;
    color: var(--text-color);
    padding-bottom: 0.75rem;
    border-bottom: 1px solid var(--border-color);
}

#input-text textarea {
    flex-grow: 1;
    border: none !important;
    box-shadow: none !important;
    padding: 0 !important;
    font-size: 1.1em;
    line-height: 1.7;
}

#result-html {
    flex-grow: 1;
    font-size: 1.1em;
    line-height: 1.7;
    overflow-y: auto;
    height: 520px;
}

#input-footer {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-top: auto; /* Push to bottom */
}

#char-counter {
    font-size: 0.9em;
    color: #9ca3af;
}
#char-counter.error {
    color: #ef4444;
}

#submit-btn {
    flex-grow: 1;
    max-width: 200px;
    font-size: 1.05em;
    font-weight: 600;
    background: var(--accent-color);
    color: white;
    border-radius: 10px;
}
#submit-btn:hover {
    background: #4f46e5;
}

.disclaimer {
    text-align: center;
    margin: 0 auto; /* Remove vertical margins */
    color: #64748b;
    font-size: 1.1em;
    max-width: 800px;
}
/* Reveal 动画更丝滑 */
@keyframes reveal {
    from { opacity: 0; }
    to { opacity: 1; }
}
.reveal-char {
    opacity: 0;
    animation: reveal 0.2s forwards;
    white-space: pre-wrap;
}
""") as demo:
    with gr.Column(elem_id="main-container"):
        # Page header: title and author line (styled via #header CSS above).
        gr.Markdown("""
<div id="header">
    <h1 id="title"><span class="detect-grad">Detect</span><span class="anyllm-grad">AnyLLM</span>: Towards Generalizable and Robust Detection of Machine-Generated Text Across Domains and Models</h1>
    <p id="authors">Jiachen Fu, Chun-Le Guo, Chongyi Li</p>
</div>
""")

        # Detection options: language of the text and whether it is plain text or LaTeX.
        with gr.Row(elem_id="controls-row"):
            language_radio = gr.Radio(
                choices=["English", "Chinese"],
                value="English",
                label="🌐 Language",
                interactive=True,
            )
            mode_radio = gr.Radio(
                choices=["Text-Only", "LaTex"],
                value="Text-Only",
                label="✍️ Input Type",
                interactive=True,
            )

        # Two side-by-side cards: input (left) and detection result (right).
        with gr.Row(equal_height=True, elem_id="content-row"):
            with gr.Column(scale=1, min_width=500):
                with gr.Column(elem_classes="card"):
                    gr.HTML('<div class="card-title">📝 Input</div>')
                    upload_btn = gr.File(
                        label="Upload File (txt, docx)",
                        file_types=['.txt', '.docx'],
                        elem_id="upload-btn",
                    )
                    input_text = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text to detect or upload a file...",
                        lines=15,
                        elem_id="input-text",
                        max_length=100000,
                    )
                    with gr.Row(elem_id="input-footer"):
                        counter_html = gr.HTML("<div id='char-counter'>0/100000</div>")
                        submit_btn = gr.Button("✨ Detect", variant="primary", elem_id="submit-btn")

            with gr.Column(scale=1, min_width=500):
                with gr.Column(elem_classes="card"):
                    gr.HTML('<div class="card-title">🔍 Result</div>')
                    result = gr.HTML(elem_id="result-html")

        gr.HTML("""
<div class="disclaimer">
    💡 <i><b style="color: red;">Red fonts</b> indicate a high probability of AI generation. <b style="color: orange;">Orange fonts</b> indicate a high probability of AI revision or polishing. The detection results are for reference only.</i>
</div>
""")

    # Fill the textbox from an uploaded .txt/.docx file.
    upload_btn.upload(
        read_file_content,
        inputs=upload_btn,
        outputs=input_text,
    )

    # Client-side character counter; runs in the browser only (fn=None).
    input_text.input(
        None,
        [input_text],
        None,
        js="""
        (text) => {
            setTimeout(() => {
                const counter = document.getElementById("char-counter");
                if (counter) {
                    const length = text.length;
                    counter.innerHTML = `${length}/100000`;
                    counter.classList.toggle("error", length > 100000);
                }
            }, 0);
            return text;
        }
        """
    )

    # Run detection: `greet` receives (input type, language, text) and returns HTML.
    submit_btn.click(
        greet,
        inputs=[mode_radio, language_radio, input_text],
        outputs=result,
    )

demo.launch(share=True)
|
core/model.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import copy
|
| 8 |
+
from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForCausalLM
|
| 9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 10 |
+
|
| 11 |
+
def from_pretrained(cls, model_name, kwargs, cache_dir):
    """Load a model or tokenizer class, preferring a local copy under ``cache_dir``.

    Args:
        cls: A class exposing a ``from_pretrained`` classmethod
            (e.g. ``AutoModelForCausalLM`` or ``AutoTokenizer``).
        model_name: Hub id (``"org/name"``) or a bare model name.
        kwargs: Extra keyword arguments forwarded to ``cls.from_pretrained``.
        cache_dir: Directory searched for a local copy, and used as the
            Hugging Face cache when downloading.

    Returns:
        Whatever ``cls.from_pretrained`` returns.
    """
    # Local copies are stored under cache_dir by their bare model name;
    # split("/")[-1] is a no-op when model_name contains no slash.
    local_path = os.path.join(cache_dir, model_name.split("/")[-1])
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    # Hub path only: NOTE(review) device_map='auto' is forwarded even when
    # `cls` is a tokenizer class, and is NOT applied on the local-path branch
    # above — confirm this asymmetry is intentional.
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir, device_map='auto')
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DiscrepancyEstimator(nn.Module):
    """Estimates the sampling discrepancy of texts under a scoring LM.

    The statistic compares the log-likelihood the scoring model assigns to a
    token sequence against the analytic mean/variance of that log-likelihood
    under a reference model's next-token distribution. Two training schemes
    are supported:

    * ``'DDL'`` — no separate reference model is required (the scoring model
      serves as its own reference).
    * ``'SPO'`` — the scoring model is trained against a frozen reference
      copy of itself (or an explicitly provided reference model).
    """

    def __init__(self,
                 scoring_model_name: str = None,
                 reference_model_name: str = None,
                 scoring_model: AutoModelForCausalLM = None,
                 reference_model: AutoModelForCausalLM = None,
                 scoring_tokenizer: AutoTokenizer = None,
                 reference_tokenizer: AutoTokenizer = None,
                 cache_dir: str = None,
                 train_method: str = 'DDL',
                 pretrained_ckpt: str = None,
                 ):
        """Build the estimator from a checkpoint, model names, or model objects.

        Args:
            scoring_model_name: Hub id / name to load the scoring model from.
                Mutually exclusive with passing ``scoring_model`` directly.
            reference_model_name: Hub id / name for the reference model.
            scoring_model / scoring_tokenizer: Pre-instantiated scoring model
                and tokenizer (both required together when no name is given).
            reference_model / reference_tokenizer: Pre-instantiated reference
                pair (both or neither).
            cache_dir: Local cache directory passed to ``from_pretrained``.
            train_method: ``'DDL'`` or ``'SPO'``.
            pretrained_ckpt: If given, everything is loaded from this PEFT
                checkpoint and all other model arguments are ignored.

        Raises:
            ValueError: On inconsistent model/tokenizer argument combinations.
        """
        super().__init__()
        assert train_method in ['DDL', 'SPO'], 'train_method should be DDL or SPO.'
        self.train_method = train_method
        self.cache_dir = cache_dir
        if pretrained_ckpt is not None:
            # A checkpoint fully determines the scoring (and optional reference) models.
            self.load_pretrained(pretrained_ckpt)
        else:
            # --- scoring model -------------------------------------------------
            if scoring_model_name is not None:
                self.scoring_model_name = scoring_model_name
                self.scoring_model = from_pretrained(
                    AutoModelForCausalLM,
                    scoring_model_name,
                    cache_dir=cache_dir,
                    kwargs=self._model_load_kwargs(scoring_model_name))
                # facebook/opt-* tokenizers are loaded with the slow implementation.
                self.scoring_tokenizer = from_pretrained(
                    AutoTokenizer,
                    scoring_model_name,
                    kwargs={'padding_side': 'right',
                            'use_fast': 'facebook/opt-' not in scoring_model_name},
                    cache_dir=cache_dir)
            else:
                if scoring_model is None or scoring_tokenizer is None:
                    raise ValueError('You should provide scoring_model_name or scoring_model and scoring_tokenizer.')
                self.scoring_model = scoring_model
                self.scoring_tokenizer = scoring_tokenizer
                self.scoring_model_name = scoring_model.config._name_or_path
            self._ensure_pad_token(self.scoring_tokenizer)

            # --- reference model ----------------------------------------------
            if reference_model_name is not None:
                self.reference_model = from_pretrained(
                    AutoModelForCausalLM,
                    reference_model_name,
                    cache_dir=cache_dir,
                    kwargs=self._model_load_kwargs(reference_model_name))
                self.reference_tokenizer = from_pretrained(
                    AutoTokenizer,
                    reference_model_name,
                    kwargs={'padding_side': 'right',
                            'use_fast': 'facebook/opt-' not in reference_model_name},
                    cache_dir=cache_dir)
                self.reference_model_name = reference_model_name
            else:
                if reference_model is None and reference_tokenizer is None:
                    if train_method == 'DDL':
                        # DDL needs no separate reference model.
                        self.reference_model = None
                        self.reference_tokenizer = None
                        self.reference_model_name = None
                    else:
                        # SPO compares against a frozen copy of the initial scoring model.
                        self.reference_model = copy.deepcopy(self.scoring_model)
                        self.reference_tokenizer = self.scoring_tokenizer
                        self.reference_model_name = self.reference_model.config._name_or_path
                elif reference_model is not None and reference_tokenizer is not None:
                    self.reference_model = reference_model
                    self.reference_tokenizer = reference_tokenizer
                    self.reference_model_name = reference_model.config._name_or_path
                else:
                    raise ValueError('You should provide reference_model and reference_tokenizer at the same time.')

            if self.reference_tokenizer is not None:
                self._ensure_pad_token(self.reference_tokenizer)

    @staticmethod
    def _model_load_kwargs(model_name: str) -> dict:
        """Per-family load kwargs: GPT-J ships a dedicated float16 revision."""
        if 'gpt-j' in model_name or 'GPT-J' in model_name:
            return dict(torch_dtype=torch.float16, revision='float16')
        return {}

    @staticmethod
    def _ensure_pad_token(tokenizer) -> None:
        """Fall back to the EOS token when a tokenizer defines no pad token."""
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

    def add_lora_config(self, lora_config: LoraConfig):
        """Wrap the scoring model with LoRA adapters for fine-tuning."""
        self.lora_config = lora_config
        self.scoring_model = get_peft_model(self.scoring_model, self.lora_config)

    def load_pretrained(self, load_directory, load_directory_ref=None):
        """
        Load the model's state_dict from the specified directory.

        Args:
            load_directory: PEFT checkpoint directory for the scoring model.
            load_directory_ref: Optional checkpoint directory for a plain
                (non-PEFT) reference model; if omitted, no reference is used.

        Raises:
            ValueError: If ``load_directory`` does not exist.
        """
        if not os.path.exists(load_directory):
            raise ValueError(f"Directory {load_directory} does not exist.")

        model_kwargs = self._model_load_kwargs(load_directory)

        self.scoring_model = AutoPeftModelForCausalLM.from_pretrained(load_directory, **model_kwargs)
        self.scoring_tokenizer = AutoTokenizer.from_pretrained(load_directory)
        self.scoring_model_name = self.scoring_model.config._name_or_path

        if load_directory_ref:
            # NOTE(review): reuses kwargs derived from `load_directory`, not from
            # `load_directory_ref` — confirm both checkpoints share a model family.
            self.reference_model = AutoModelForCausalLM.from_pretrained(load_directory_ref, **model_kwargs)
            self.reference_tokenizer = AutoTokenizer.from_pretrained(load_directory_ref)
            self.reference_model_name = self.reference_model.config._name_or_path
        else:
            self.reference_model = None
            self.reference_tokenizer = None
            self.reference_model_name = None

        self._ensure_pad_token(self.scoring_tokenizer)
        if self.reference_tokenizer is not None:
            self._ensure_pad_token(self.reference_tokenizer)

    def get_sampling_discrepancy_analytic(self, reference_logits, scoring_logits, labels, attention_mask):
        """Compute the analytic sampling-discrepancy statistic per sequence.

        Args:
            reference_logits: [bsz, seq_len-1, vocab] logits of the reference model.
            scoring_logits: [bsz, seq_len-1, vocab] logits of the scoring model.
            labels: [bsz, seq_len-1] target token ids (inputs shifted by one).
            attention_mask: [bsz, seq_len] mask; 1 for real tokens, 0 for padding.

        Returns:
            Tuple ``(discrepancy, log_likelihood_sum)``, each shaped [bsz].
        """
        # Align vocabularies when the two models use different vocab sizes.
        if reference_logits.size(-1) != scoring_logits.size(-1):
            vocab_size = min(reference_logits.size(-1), scoring_logits.size(-1))
            reference_logits = reference_logits[:, :, :vocab_size]
            scoring_logits = scoring_logits[:, :, :vocab_size]

        # Ensure labels carry a trailing index dim for gather().
        labels = labels.unsqueeze(-1) if labels.ndim == scoring_logits.ndim - 1 else labels
        lprobs_score = torch.log_softmax(scoring_logits, dim=-1)
        probs_ref = torch.softmax(reference_logits, dim=-1)

        # Per-token log-likelihood under the scoring model, and the analytic
        # mean/variance of that quantity under the reference distribution.
        log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
        mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
        var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)

        mask = attention_mask[:, 1:].float()  # [bsz, seq_len-1], 1 for non-pad, 0 for pad
        log_likelihood_sum = (log_likelihood * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        mean_ref_sum = (mean_ref * mask).sum(dim=-1)              # [bsz], sum over non-pad tokens
        var_ref_sum = (var_ref * mask).sum(dim=-1)                # [bsz], sum over non-pad tokens
        # Normalized deviation; epsilon avoids division by zero.
        discrepancy = (log_likelihood_sum - mean_ref_sum) / (var_ref_sum.sqrt() + 1e-8)

        return discrepancy, log_likelihood_sum

    def get_discrepancy_of_scoring_and_reference_models(self,
                                                        input_ids_for_scoring_model,
                                                        attention_mask_for_scoring_model,
                                                        input_ids_for_reference_model=None,
                                                        attention_mask_for_reference_model=None,
                                                        ) -> dict:
        """Run both models on one batch and return their discrepancy statistics.

        The reference model (if any) is evaluated under ``torch.no_grad``;
        gradients flow only through the scoring model. Both tokenizations must
        produce identical label sequences.

        Returns:
            Dict with keys ``scoring_discrepancy``, ``scoring_logprob``,
            ``reference_discrepancy``, ``reference_logprob`` (the reference
            entries are ``None`` when no reference model is configured).
        """
        labels = input_ids_for_scoring_model[:, 1:]  # shape: [bsz, sentence_len - 1]
        scoring_logits = self.scoring_model(input_ids_for_scoring_model,
                                            attention_mask=attention_mask_for_scoring_model).logits[:, :-1, :]
        if self.reference_model is not None:
            assert input_ids_for_reference_model is not None and attention_mask_for_reference_model is not None, \
                "If reference_model is provided, you should provide reference_tokenizer to dataset initialization."
            with torch.no_grad():
                # Both tokenizers must agree token-for-token, otherwise the
                # per-position comparison below would be meaningless.
                reference_labels = input_ids_for_reference_model[:, 1:]
                assert torch.all(reference_labels == labels), \
                    "Tokenizer is mismatch."
                reference_logits = self.reference_model(input_ids_for_reference_model,
                                                        attention_mask=attention_mask_for_reference_model).logits[:, :-1, :]
        else:
            # Without a reference model the scoring model is its own reference.
            reference_logits = scoring_logits

        if self.reference_model is not None:
            # Self-discrepancy of the reference model (reference vs. reference).
            discrepancy_ref, logprob_ref = self.get_sampling_discrepancy_analytic(
                reference_logits, reference_logits,
                labels, attention_mask=attention_mask_for_reference_model)
        else:
            discrepancy_ref, logprob_ref = None, None
        discrepancy_score, logprob_score = self.get_sampling_discrepancy_analytic(
            reference_logits, scoring_logits,
            labels, attention_mask=attention_mask_for_scoring_model)

        return {
            'scoring_discrepancy': discrepancy_score,
            'scoring_logprob': logprob_score,
            'reference_discrepancy': discrepancy_ref,
            'reference_logprob': logprob_ref,
        }

    def forward(self,
                scoring_original_input_ids,
                scoring_original_attention_mask,
                scoring_rewritten_input_ids,
                scoring_rewritten_attention_mask,
                reference_original_input_ids=None,
                reference_original_attention_mask=None,
                reference_rewritten_input_ids=None,
                reference_rewritten_attention_mask=None,
                ) -> dict:
        """Score an (original, rewritten) text pair under both models.

        ``SPO`` requires reference-model inputs for both texts; ``DDL``
        forbids them (the scoring model is its own reference).

        Returns:
            Dict of per-sequence discrepancies and log-probs for the
            original/rewritten texts under the scoring/reference models.

        Raises:
            ValueError: If ``self.train_method`` is neither DDL nor SPO.
        """
        if self.train_method == 'SPO':
            assert reference_original_input_ids is not None and reference_original_attention_mask is not None, \
                "If train_method is SPO, you should provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is not None and reference_rewritten_attention_mask is not None, \
                "If train_method is SPO, you should provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        elif self.train_method == 'DDL':
            assert reference_original_input_ids is None and reference_original_attention_mask is None, \
                "If train_method is DDL, you should not provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is None and reference_rewritten_attention_mask is None, \
                "If train_method is DDL, you should not provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        else:
            raise ValueError('train_method should be DDL or SPO.')
        original_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_original_input_ids,
            attention_mask_for_scoring_model=scoring_original_attention_mask,
            input_ids_for_reference_model=reference_original_input_ids,
            attention_mask_for_reference_model=reference_original_attention_mask,
        )
        rewritten_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_rewritten_input_ids,
            attention_mask_for_scoring_model=scoring_rewritten_attention_mask,
            input_ids_for_reference_model=reference_rewritten_input_ids,
            attention_mask_for_reference_model=reference_rewritten_attention_mask,
        )

        return {
            'scoring_original_discrepancy': original_output['scoring_discrepancy'],
            'scoring_original_logprob': original_output['scoring_logprob'],
            'scoring_rewritten_discrepancy': rewritten_output['scoring_discrepancy'],
            'scoring_rewritten_logprob': rewritten_output['scoring_logprob'],
            'reference_original_discrepancy': original_output['reference_discrepancy'],
            'reference_original_logprob': original_output['reference_logprob'],
            'reference_rewritten_discrepancy': rewritten_output['reference_discrepancy'],
            'reference_rewritten_logprob': rewritten_output['reference_logprob'],
        }
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
peft
torch
transformers
protobuf
python-docx
gradio
numpy
huggingface_hub
datasets
spaces