import html
import re

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import matplotlib

matplotlib.use("Agg")  # non-interactive backend: no display is needed to render plots
import matplotlib.pyplot as plt  # noqa: E402,F401  (kept for backward compatibility)
from matplotlib.figure import Figure  # noqa: E402
from tokenizers.normalizers import Sequence, Replace, Strip  # noqa: E402
from tokenizers import Regex  # noqa: E402

# -------------------------
# Device setup
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Label Mapping
# -------------------------
# Classifier output index -> source label. Defined before the models are loaded so
# the classification-head size can be derived from it instead of hard-coding 41.
label_mapping = {
    0: "13B", 1: "30B", 2: "65B", 3: "7B", 4: "GLM130B", 5: "bloom_7b",
    6: "bloomz", 7: "cohere", 8: "davinci", 9: "dolly", 10: "dolly-v2-12b",
    11: "flan_t5_base", 12: "flan_t5_large", 13: "flan_t5_small",
    14: "flan_t5_xl", 15: "flan_t5_xxl", 16: "gemma-7b-it", 17: "gemma2-9b-it",
    18: "gpt-3.5-turbo", 19: "gpt-35", 20: "gpt4", 21: "gpt4o", 22: "gpt_j",
    23: "gpt_neox", 24: "human", 25: "llama3-70b", 26: "llama3-8b",
    27: "mixtral-8x7b", 28: "opt_1.3b", 29: "opt_125m", 30: "opt_13b",
    31: "opt_2.7b", 32: "opt_30b", 33: "opt_350m", 34: "opt_6.7b",
    35: "opt_iml_30b", 36: "opt_iml_max_1.3b", 37: "t0_11b", 38: "t0_3b",
    39: "text-davinci-002", 40: "text-davinci-003",
}

# Index of the single human-written class; every other index names an AI generator.
HUMAN_LABEL_ID = next(k for k, v in label_mapping.items() if v == "human")

# -------------------------
# Model and Tokenizer Setup
# -------------------------
model1_path = "modernbert.bin"
model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")


def safe_load_model(base_name, weights_path, num_labels=41):
    """Build a sequence-classification model and load fine-tuned weights into it.

    Args:
        base_name: Hugging Face model id providing the architecture/config.
        weights_path: Local checkpoint path, or an ``http``/``https`` URL fetched
            with ``torch.hub.load_state_dict_from_url``.
        num_labels: Size of the classification head; must match the checkpoint.

    Returns:
        The model with weights loaded, moved to ``device`` and set to eval mode.
    """
    model = AutoModelForSequenceClassification.from_pretrained(base_name, num_labels=num_labels)
    if weights_path.startswith("http"):
        state_dict = torch.hub.load_state_dict_from_url(weights_path, map_location=device)
    else:
        # weights_only=True: only tensors/containers are unpickled, never arbitrary objects.
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model


print("Loading models...")
model_1, model_2, model_3 = (
    safe_load_model("answerdotai/ModernBERT-base", path, num_labels=len(label_mapping))
    for path in (model1_path, model2_path, model3_path)
)

# -------------------------
# Text Cleaning
# -------------------------
def clean_text(text: str) -> str:
    """Collapse whitespace runs to one space and drop whitespace before punctuation.

    Any run of two or more whitespace characters, blank lines included, becomes a
    single space, so paragraph boundaries must be found *before* calling this.
    """
    text = re.sub(r"\s{2,}", " ", text)
    text = re.sub(r"\s+([,.;:?!])", r"\1", text)
    return text


# A paragraph break is a blank line (optionally containing spaces/tabs/CRs).
_PARAGRAPH_BREAK = re.compile(r"\n\s*\n")


def split_paragraphs(text: str) -> list[str]:
    """Split raw text on blank lines, clean each piece, and drop empty ones."""
    cleaned = (clean_text(part).strip() for part in _PARAGRAPH_BREAK.split(text))
    return [p for p in cleaned if p]


# Make the tokenizer treat single line breaks inside a paragraph as spaces and trim
# the ends. The existing normalizer may be None, which Sequence does not accept.
newline_to_space = Replace(Regex(r"\s*\n\s*"), " ")
tokenizer.backend_tokenizer.normalizer = Sequence([
    n
    for n in (tokenizer.backend_tokenizer.normalizer, newline_to_space, Strip())
    if n is not None
])

# -------------------------
# Classification Function
# -------------------------
PREVIEW_CHARS = 150  # characters of each paragraph shown in the breakdown


def _ensemble_probabilities(inputs):
    """Return the softmax class probabilities averaged over the three models (row 0)."""
    with torch.no_grad():
        summed = sum(
            torch.softmax(model(**inputs).logits, dim=1)
            for model in (model_1, model_2, model_3)
        )
    return (summed / 3)[0]


def _score_paragraph(paragraph):
    """Classify one paragraph.

    Returns:
        ``(score, probabilities)`` where ``score`` holds a text preview, the human
        and AI percentages, and the most likely AI label, and ``probabilities`` is
        the averaged probability vector moved to CPU.
    """
    inputs = tokenizer(paragraph, return_tensors="pt", truncation=True, padding=True).to(device)
    probabilities = _ensemble_probabilities(inputs)

    human_prob = probabilities[HUMAN_LABEL_ID].item()
    ai_probs = probabilities.clone()
    ai_probs[HUMAN_LABEL_ID] = 0  # mask the human class so argmax picks an AI label
    ai_total_prob = ai_probs.sum().item()
    total = human_prob + ai_total_prob

    score = {
        "paragraph": paragraph[:PREVIEW_CHARS] + ("..." if len(paragraph) > PREVIEW_CHARS else ""),
        "human": human_prob / total * 100,
        "ai": ai_total_prob / total * 100,
        "model": label_mapping[torch.argmax(ai_probs).item()],
    }
    return score, probabilities.cpu()


def _overall_result(chunk_scores):
    """Return the HTML headline from the unweighted mean of per-paragraph scores."""
    n = len(chunk_scores)
    avg_human = sum(c["human"] for c in chunk_scores) / n
    avg_ai = sum(c["ai"] for c in chunk_scores) / n
    if avg_human > avg_ai:
        return (
            "<b>Overall Result:</b> "
            f"<span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
        )
    top_model = max(chunk_scores, key=lambda c: c["ai"])["model"]
    return (
        "<b>Overall Result:</b> "
        f"<span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"
    )


def _render_paragraphs(chunk_scores):
    """Render one colour-coded HTML card per paragraph; user text is HTML-escaped."""
    cards = ["<h3>Paragraph Analysis:</h3>"]
    for idx, c in enumerate(chunk_scores, 1):
        color = "#4CAF50" if c["human"] > c["ai"] else "#FF5733"
        cards.append(
            f"<div style='border-left: 5px solid {color}; padding-left: 10px; margin-bottom: 10px;'>"
            f"<b>Paragraph {idx}:</b> {c['human']:.2f}% Human | {c['ai']:.2f}% AI "
            f"→ <i>{c['model']}</i><br>"
            f"<small>{html.escape(c['paragraph'])}</small>"
            "</div>"
        )
    return "".join(cards)


def _plot_top_predictions(all_probabilities, k=5):
    """Bar-chart the top-k labels of the per-paragraph probabilities averaged together."""
    mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
    top_probs, top_indices = torch.topk(mean_probs, min(k, mean_probs.numel()))
    top_probs = top_probs.cpu().numpy()
    top_labels = [label_mapping[i.item()] for i in top_indices]

    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive in a global registry, which leaks one per request.
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    bars = ax.barh(top_labels, top_probs, color="#4CAF50")
    ax.set_xlabel("Probability")
    ax.set_title(f"Top {len(top_labels)} Model Predictions")
    ax.invert_yaxis()  # topk is sorted descending; put the most likely label on top
    for bar in bars:
        width = bar.get_width()
        ax.text(width + 0.005, bar.get_y() + bar.get_height() / 2, f"{width:.2%}", va="center")
    fig.tight_layout()
    return fig


def classify_text(text):
    """Classify each blank-line-separated paragraph of ``text`` with the model ensemble.

    Returns:
        ``(html, figure)``: an HTML summary plus per-paragraph breakdown, and a
        matplotlib Figure of the top predicted labels. For empty or whitespace-only
        input, returns ``("Please enter some text to analyze.", None)``.
    """
    # Split on blank lines BEFORE cleaning: clean_text erases paragraph breaks.
    paragraphs = split_paragraphs(text or "")
    if not paragraphs:
        return "Please enter some text to analyze.", None

    chunk_scores = []
    all_probabilities = []
    for paragraph in paragraphs:
        score, probabilities = _score_paragraph(paragraph)
        chunk_scores.append(score)
        all_probabilities.append(probabilities)

    result_message = _overall_result(chunk_scores)
    paragraph_html = _render_paragraphs(chunk_scores)
    fig = _plot_top_predictions(all_probabilities)
    return result_message + "<br><br>" + paragraph_html, fig


# -------------------------
# UI Setup
# -------------------------
title = "AI Text Detector"
description = """
This tool uses ModernBERT to detect AI-generated text.
Each paragraph is analyzed separately to show which parts are likely AI-generated.
"""
bottom_text = "**Developed by SzegedAI – Extended by Saber**"

AI_texts = [
    "Artificial intelligence (AI) is reshaping industries by automating tasks, enhancing decision-making, and driving innovation. From predictive analytics in finance to autonomous vehicles in transportation, AI technologies are becoming integral to daily operations."
]

Human_texts = [
    "Mathematics has always been a cornerstone of scientific discovery. It provides a precise language for describing natural phenomena, from the orbit of planets to the behavior of subatomic particles."
]

iface = gr.Blocks(css="""
@import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400;700&display=swap');
body { font-family: 'Roboto Mono', sans-serif !important; }
.highlight-human { color: #4CAF50; font-weight: bold; }
.highlight-ai { color: #FF5733; font-weight: bold; }
""")

with iface:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    text_input = gr.Textbox(label="", placeholder="Paste your article here...", lines=10)
    analyze_btn = gr.Button("🔍 Analyze", variant="primary")
    result_output = gr.HTML(label="Result")
    plot_output = gr.Plot(label="Model Probability Distribution")
    analyze_btn.click(classify_text, inputs=text_input, outputs=[result_output, plot_output])

    with gr.Tab("AI Examples"):
        gr.Examples(AI_texts, inputs=text_input)
    with gr.Tab("Human Examples"):
        gr.Examples(Human_texts, inputs=text_input)

    gr.Markdown(bottom_text)

if __name__ == "__main__":
    iface.launch(share=True)