File size: 5,464 Bytes
849f4a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6b6b2d
849f4a8
 
 
 
 
 
 
 
 
 
a6b6b2d
 
 
 
 
849f4a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import re
from tokenizers import normalizers
from tokenizers.normalizers import Sequence, Replace, Strip, NFKC
from tokenizers import Regex
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Remote checkpoints for the three ensemble members. All three are
# ModernBERT-base fine-tunes sharing the same 41-way label head.
model1_path = "https://huggingface.co/spaces/SzegedAI/AI_Detector/resolve/main/modernbert.bin"
model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"

print("Loading models...")

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")


def _load_ensemble_member(url, loading_msg, label):
    """Build a 41-label ModernBERT classifier and load *url* into it.

    On any failure (network, checksum, key mismatch) a fresh, un-finetuned
    base model is returned instead so the module still imports — matching
    the original best-effort behavior rather than crashing at import time.

    url:         direct download URL of the state dict.
    loading_msg: text for the "Loading ..." progress line.
    label:       display name, e.g. "Model 1" (lower-cased in the warning).
    Returns the model moved to `device` and set to eval mode.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        "answerdotai/ModernBERT-base", num_labels=41
    )
    try:
        print(f"Loading {loading_msg}...")
        model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location=device))
        print(f"✅ {label} loaded successfully")
    except Exception as e:
        print(f"⚠️ Could not load {label.lower()}: {e}")
        # Rebuild from scratch: a failed strict load may leave the model
        # partially populated.
        model = AutoModelForSequenceClassification.from_pretrained(
            "answerdotai/ModernBERT-base", num_labels=41
        )
    return model.to(device).eval()


model_1 = _load_ensemble_member(model1_path, "model1 from SzegedAI/AI_Detector", "Model 1")
model_2 = _load_ensemble_member(model2_path, "model2 from mihalykiss/modernbert_2", "Model 2")
model_3 = _load_ensemble_member(model3_path, "model3 from mihalykiss/modernbert_2", "Model 3")

print("✅ All models loaded successfully!")

# Class-index → generator-name table for the 41-way classifier head.
# Index 24 ('human') is the only non-AI class; classify_text() zeroes it
# out when ranking the most likely AI source.
label_mapping = {
    0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
    6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
    11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
    14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
    18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
    22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
    27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
    31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
    35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
    39: 'text-davinci-002', 40: 'text-davinci-003'
}

def clean_text(text: str) -> str:
    """Normalize spacing: collapse whitespace runs, unglue punctuation.

    Runs of two or more whitespace characters become a single space, and
    any whitespace immediately before , . ; : ? ! is removed.
    """
    collapsed = re.sub(r'\s{2,}', ' ', text)
    return re.sub(r'\s+([,.;:?!])', r'\1', collapsed)

# Pre-tokenization normalization: undo PDF-style layout artifacts (line
# breaks mid-sentence, hyphenated line wraps) before the model sees the text.
newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
# NOTE(review): the class `[--]` matches only a literal '-'; it looks like
# Unicode hyphen/dash variants were lost in transit — confirm against the
# original source before changing it.
join_hyphen_break = Replace(Regex(r'(\w+)[--]\s*\n\s*(\w+)'), r"\1\2")

# Chain the extra rules AFTER the tokenizer's stock normalizer so its
# default behavior is preserved; order matters (re-join hyphen breaks
# before newlines are flattened to spaces).
tokenizer.backend_tokenizer.normalizer = Sequence([
    tokenizer.backend_tokenizer.normalizer,
    join_hyphen_break,
    newline_to_space,
    Strip()
])

def classify_text(text):
    """Classify *text* with the three-model ModernBERT ensemble.

    The three models' softmax distributions are averaged; probability mass
    on the 'human' class is compared against the combined mass of all AI
    classes. Returns a Markdown report string: either a "Human Written"
    verdict or an "AI Generated" verdict naming the most likely source
    model, or a warning string for empty input.

    Author: deveshpunjabi
    Date: 2025-01-15 07:07:03 UTC
    """
    # Reject empty/whitespace-only input before doing any model work
    # (the original cleaned first, then checked — same observable result).
    if not text.strip():
        return "⚠️ Please enter some text to analyze"

    cleaned_text = clean_text(text)

    inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True).to(device)

    # Simple ensemble: average the per-model softmax distributions.
    with torch.no_grad():
        member_probs = [
            torch.softmax(model(**inputs).logits, dim=1)
            for model in (model_1, model_2, model_3)
        ]
        probabilities = (sum(member_probs) / len(member_probs))[0]

    # Derive the 'human' class index from label_mapping instead of
    # hard-coding 24, so the two can never drift apart.
    human_index = next(i for i, name in label_mapping.items() if name == 'human')

    ai_probs = probabilities.clone()
    ai_probs[human_index] = 0
    ai_total_prob = ai_probs.sum().item() * 100
    human_prob = 100 - ai_total_prob

    # Best AI candidate (human class already zeroed out above).
    ai_argmax_model = label_mapping[torch.argmax(ai_probs).item()]

    # Shared tail of both verdict messages (was duplicated verbatim).
    details = f"""---
**Analysis Details:**
- Human probability: {human_prob:.2f}%
- AI probability: {ai_total_prob:.2f}%
- Text length: {len(cleaned_text)} characters
"""

    if human_prob > ai_total_prob:
        return f"""
### 🟢 **Human Written**
**Confidence: {human_prob:.2f}%**
This text appears to be written by a human.
""" + details

    return f"""
### 🔴 **AI Generated**
**Confidence: {ai_total_prob:.2f}%**
**Most likely source: {ai_argmax_model}**
This text appears to be generated by an AI model.
""" + details