# Spaces: Sleeping
import html
import re

import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence, Strip
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# -------------------------
# Device setup
# -------------------------
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -------------------------
# Model and Tokenizer Setup
# -------------------------
# One tokenizer is shared by all three ensemble members, which are all built
# on the same base model.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# State-dict checkpoints for the three ensemble members: a local file plus two
# checkpoints downloaded from the Hugging Face Hub.
model1_path = "modernbert.bin"
model2_path = (
    "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
)
model3_path = (
    "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
)
def safe_load_model(base_name, weights_path, num_labels=41):
    """Build a sequence classifier from *base_name* and load a saved state dict into it.

    Args:
        base_name: Hugging Face model id used for the architecture and config.
        weights_path: Local file path, or an ``http(s)`` URL, of a saved ``state_dict``.
        num_labels: Size of the classification head. It must match the checkpoint.
            The default of 41 is one per entry in the label mapping.

    Returns:
        The model with the weights loaded, moved to ``device`` and put in eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint keys or shapes do not
            match the freshly built model (strict loading).
    """
    model = AutoModelForSequenceClassification.from_pretrained(base_name, num_labels=num_labels)
    # Checkpoints are pickles. With weights_only=True, unpickling is restricted to
    # tensors and plain containers, so a tampered checkpoint (downloaded or local)
    # cannot execute code on load. torch.hub's loader defaults to weights_only=False,
    # so it is passed explicitly to both branches.
    if weights_path.startswith("http"):
        state_dict = torch.hub.load_state_dict_from_url(
            weights_path, map_location=device, weights_only=True
        )
    else:
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model
print("Loading models...")
# Ensemble of three classifiers: same base architecture, different checkpoints.
model_1, model_2, model_3 = (
    safe_load_model("answerdotai/ModernBERT-base", checkpoint)
    for checkpoint in (model1_path, model2_path, model3_path)
)
# -------------------------
# Label Mapping
# -------------------------
# Classifier output index -> source label. The list position is the class
# index. Index 24 ("human") is the only human class; every other entry is an
# AI source label.
label_mapping = dict(enumerate([
    # 0-5
    "13B", "30B", "65B", "7B", "GLM130B", "bloom_7b",
    # 6-10
    "bloomz", "cohere", "davinci", "dolly", "dolly-v2-12b",
    # 11-15
    "flan_t5_base", "flan_t5_large", "flan_t5_small", "flan_t5_xl", "flan_t5_xxl",
    # 16-21
    "gemma-7b-it", "gemma2-9b-it", "gpt-3.5-turbo", "gpt-35", "gpt4", "gpt4o",
    # 22-27
    "gpt_j", "gpt_neox", "human", "llama3-70b", "llama3-8b", "mixtral-8x7b",
    # 28-34
    "opt_1.3b", "opt_125m", "opt_13b", "opt_2.7b", "opt_30b", "opt_350m", "opt_6.7b",
    # 35-40
    "opt_iml_30b", "opt_iml_max_1.3b", "t0_11b", "t0_3b", "text-davinci-002", "text-davinci-003",
]))
# -------------------------
# Text Cleaning
# -------------------------
def clean_text(text: str) -> str:
    """Normalize spacing in *text*.

    Any run of two or more whitespace characters is replaced by one space. Any
    whitespace directly before ``, . ; : ? !`` is removed. A single whitespace
    character (including a lone newline or tab) is left untouched.

    Note: a blank line (``"\\n\\n"``) is itself a run of whitespace. It is
    therefore collapsed too, so any paragraph splitting must happen before this
    function is applied.
    """
    single_spaced = re.sub(r'\s{2,}', ' ', text)
    return re.sub(r'\s+([,.;:?!])', r'\1', single_spaced)
# Append two steps to the tokenizer's existing normalizer, which still runs first:
#   1. every newline, plus any whitespace around it, becomes a single space;
#   2. Strip() (default arguments) trims the ends of the normalized text.
newline_to_space = Replace(Regex(r"\s*\n\s*"), " ")
tokenizer.backend_tokenizer.normalizer = Sequence(
    [tokenizer.backend_tokenizer.normalizer, newline_to_space, Strip()]
)
# -------------------------
# Classification Function
# -------------------------
_HUMAN_INDEX = 24  # label_mapping[24] == "human"; every other class is an AI source
# Paragraphs are separated by a blank line, which may contain spaces/tabs or "\r".
_PARAGRAPH_SPLIT = re.compile(r"\n\s*\n")
_PREVIEW_CHARS = 150  # characters of each paragraph shown in the breakdown


def _split_paragraphs(text):
    """Split raw text on blank lines, clean each piece, and drop empty ones.

    The split must happen before clean_text: clean_text turns any whitespace run
    of length >= 2 (including "\\n\\n") into a single space, which would erase
    the paragraph boundaries.
    """
    cleaned = (clean_text(part).strip() for part in _PARAGRAPH_SPLIT.split(text))
    return [p for p in cleaned if p]


def _ensemble_probabilities(paragraph):
    """Return the 1-D class-probability vector (on CPU), averaged over the three models."""
    inputs = tokenizer(paragraph, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        softmaxes = [torch.softmax(m(**inputs).logits, dim=1) for m in (model_1, model_2, model_3)]
    # Single input, so take row 0 of the (1, num_labels) batch.
    return (sum(softmaxes) / len(softmaxes))[0].cpu()


def _score_paragraph(probabilities):
    """Collapse a class-probability vector into human %, AI %, and the top AI label."""
    human_prob = probabilities[_HUMAN_INDEX].item()
    ai_probs = probabilities.clone()
    ai_probs[_HUMAN_INDEX] = 0  # remaining mass is spread over the AI classes
    ai_prob = ai_probs.sum().item()
    total = human_prob + ai_prob
    return {
        "human": (human_prob / total) * 100,
        "ai": (ai_prob / total) * 100,
        "model": label_mapping[torch.argmax(ai_probs).item()],
    }


def _overall_html(chunk_scores):
    """Render the overall verdict from the unweighted mean of the per-paragraph scores."""
    avg_human = sum(c["human"] for c in chunk_scores) / len(chunk_scores)
    avg_ai = sum(c["ai"] for c in chunk_scores) / len(chunk_scores)
    # The output component is HTML, so use <b> rather than Markdown "**" (which
    # would be shown literally).
    if avg_human > avg_ai:
        return f"<b>Overall Result:</b> <span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
    # Attribute the text to the generator of the most AI-like paragraph.
    top_model = max(chunk_scores, key=lambda c: c["ai"])["model"]
    return f"<b>Overall Result:</b> <span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"


def _paragraph_html(chunk_scores):
    """Render the per-paragraph breakdown; user text is HTML-escaped."""
    parts = ["<h3>Paragraph Analysis:</h3>"]
    for idx, c in enumerate(chunk_scores, 1):
        color = "#4CAF50" if c["human"] > c["ai"] else "#FF5733"
        parts.append(
            f"<div style='margin-bottom:10px; border-left:4px solid {color}; padding-left:10px;'>"
            f"<b>Paragraph {idx}</b>: {c['human']:.2f}% Human | {c['ai']:.2f}% AI → <i>{c['model']}</i><br>"
            f"<small>{html.escape(c['paragraph'])}</small></div>\n"
        )
    return "".join(parts)


def _top5_figure(all_probabilities):
    """Bar chart of the five labels with the highest probability averaged over paragraphs."""
    mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
    top_probs, top_indices = torch.topk(mean_probs, 5)
    top_probs = top_probs.numpy()
    top_labels = [label_mapping[i.item()] for i in top_indices]

    fig, ax = plt.subplots(figsize=(10, 5))
    bars = ax.barh(top_labels, top_probs, color="#4CAF50")
    ax.set_xlabel("Probability")
    ax.set_title("Top 5 Model Predictions")
    ax.invert_yaxis()  # highest probability on top
    for bar in bars:
        width = bar.get_width()
        ax.text(width + 0.005, bar.get_y() + bar.get_height() / 2, f"{width:.2%}", va="center")
    # Lay out this figure explicitly rather than pyplot's "current" one, then
    # unregister it from pyplot. Otherwise every request leaves a live figure
    # behind. The returned Figure object can still be rendered by the caller.
    fig.tight_layout()
    plt.close(fig)
    return fig


def classify_text(text):
    """Classify *text* as human- or AI-written, paragraph by paragraph.

    The input is split on blank lines. Each paragraph is scored by averaging the
    softmax outputs of the three ensemble models over the labels in
    ``label_mapping``.

    Args:
        text: Raw user input from the textbox; ``None`` is treated as empty.

    Returns:
        ``(html, figure)``: an HTML summary (overall verdict plus per-paragraph
        breakdown) and a matplotlib figure of the top-5 mean label
        probabilities. If there is nothing to analyze, an HTML error message and
        ``None`` are returned.
    """
    paragraphs = _split_paragraphs(text or "")
    if not paragraphs:
        return "<b style='color:red;'>Please enter some text to analyze.</b>", None

    chunk_scores = []
    all_probabilities = []
    for paragraph in paragraphs:
        probabilities = _ensemble_probabilities(paragraph)
        all_probabilities.append(probabilities)
        score = _score_paragraph(probabilities)
        score["paragraph"] = paragraph[:_PREVIEW_CHARS] + ("..." if len(paragraph) > _PREVIEW_CHARS else "")
        chunk_scores.append(score)

    report = _overall_html(chunk_scores) + "<br><br>" + _paragraph_html(chunk_scores)
    return report, _top5_figure(all_probabilities)
# -------------------------
# UI Setup
# -------------------------
title = "AI Text Detector"

# Intro text shown under the title (contains inline HTML tags).
description = (
    "\n"
    "This tool uses <b>ModernBERT</b> to detect AI-generated text.<br>\n"
    "Each paragraph is analyzed separately to show which parts are likely AI-generated.\n"
)

bottom_text = "**Developed by SzegedAI – Extended by Saber**"

# One-click sample inputs for the two example tabs.
AI_texts = [
    "Artificial intelligence (AI) is reshaping industries by automating tasks, "
    "enhancing decision-making, and driving innovation. From predictive analytics "
    "in finance to autonomous vehicles in transportation, AI technologies are "
    "becoming integral to daily operations."
]
Human_texts = [
    "Mathematics has always been a cornerstone of scientific discovery. "
    "It provides a precise language for describing natural phenomena, "
    "from the orbit of planets to the behavior of subatomic particles."
]
# Page layout: header, input box, analyze button, the two outputs, example
# tabs, and a footer. The CSS classes are used by the verdict HTML.
with gr.Blocks(css="""
@import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400;700&display=swap');
body { font-family: 'Roboto Mono', sans-serif !important; }
.highlight-human { color: #4CAF50; font-weight: bold; }
.highlight-ai { color: #FF5733; font-weight: bold; }
""") as iface:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    text_input = gr.Textbox(label="", placeholder="Paste your article here...", lines=10)
    analyze_btn = gr.Button("🔍 Analyze", variant="primary")
    result_output = gr.HTML(label="Result")
    plot_output = gr.Plot(label="Model Probability Distribution")

    # classify_text returns (html, figure), matching these outputs in order.
    analyze_btn.click(
        fn=classify_text,
        inputs=[text_input],
        outputs=[result_output, plot_output],
    )

    with gr.Tab("AI Examples"):
        gr.Examples(examples=AI_texts, inputs=text_input)
    with gr.Tab("Human Examples"):
        gr.Examples(examples=Human_texts, inputs=text_input)

    gr.Markdown(bottom_text)

iface.launch(share=True)