# AI Text Detector — Hugging Face Spaces app
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re

# Run inference on the GPU when one is present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single tokenizer is shared by every member of the ensemble.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Hub repos for the ensemble checkpoints (replace with the real Hugging Face
# repos if the names differ).
model_names = [
    "mihalykiss/modernbert_2_seed12",
    "mihalykiss/modernbert_2_seed22",
    "mihalykiss/modernbert_2_seed32"
]

# Pull each checkpoint from the Hub, move it to the target device, and
# freeze it in eval mode.
models = [
    AutoModelForSequenceClassification.from_pretrained(repo).to(device).eval()
    for repo in model_names
]
| # Label map | |
# Class index -> label name for the 41 classifier outputs.  Index 24
# ("human") is the only non-AI class; every other index names a generator
# model.  Built from an ordered list so the positions double as the keys.
label_mapping = dict(enumerate([
    '13B', '30B', '65B', '7B', 'GLM130B', 'bloom_7b',
    'bloomz', 'cohere', 'davinci', 'dolly', 'dolly-v2-12b',
    'flan_t5_base', 'flan_t5_large', 'flan_t5_small',
    'flan_t5_xl', 'flan_t5_xxl', 'gemma-7b-it', 'gemma2-9b-it',
    'gpt-3.5-turbo', 'gpt-35', 'gpt4', 'gpt4o',
    'gpt_j', 'gpt_neox', 'human', 'llama3-70b', 'llama3-8b',
    'mixtral-8x7b', 'opt_1.3b', 'opt_125m', 'opt_13b',
    'opt_2.7b', 'opt_30b', 'opt_350m', 'opt_6.7b',
    'opt_iml_30b', 'opt_iml_max_1.3b', 't0_11b', 't0_3b',
    'text-davinci-002', 'text-davinci-003',
]))
| # Text cleanup | |
def clean_text(text: str) -> str:
    """Normalise whitespace in *text*.

    Runs of two or more whitespace characters collapse to a single space,
    whitespace immediately before punctuation (, . ; : ? !) is dropped,
    and leading/trailing whitespace is stripped.
    """
    collapsed = re.sub(r"\s{2,}", " ", text)
    tightened = re.sub(r"\s+([,.;:?!])", r"\1", collapsed)
    return tightened.strip()
| # Classification function | |
def classify_text(text):
    """Classify each sentence of *text* as AI-like or human-like.

    The text is cleaned, split into sentences, and each sentence is scored
    by averaging the softmax distributions of the model ensemble.  Class 24
    ("human") supplies the human probability; the summed mass of every
    other class counts as the AI score.  Returns an HTML string with
    per-sentence highlighting followed by an overall verdict.
    """
    # Local import so this fix does not touch the top-level import block.
    import html

    cleaned_text = clean_text(text)
    if not cleaned_text:
        return "Please paste some text."

    # Naive sentence splitter: break after ., ! or ? followed by whitespace.
    sentences = re.split(r'(?<=[.!?])\s+', cleaned_text)

    highlighted = []
    total_ai, total_human = 0, 0
    for sent in sentences:
        if not sent.strip():
            continue
        inputs = tokenizer(
            sent, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            # Ensemble: average the softmax distributions of all models.
            probs_list = [
                torch.softmax(m(**inputs).logits, dim=1) for m in models
            ]
        probs = (sum(probs_list) / len(probs_list))[0]

        # Human class = 24; everything else is treated as AI mass.
        ai_probs = probs.clone()
        ai_probs[24] = 0
        ai_score = ai_probs.sum().item() * 100
        human_score = 100 - ai_score
        total_ai += ai_score
        total_human += human_score

        # FIX: escape user text before embedding it in HTML — unescaped
        # '<', '>', '&' in the input would inject markup / break rendering.
        safe_sent = html.escape(sent)
        css_class = "highlight-ai" if ai_score > 20 else "highlight-human"
        highlighted.append(f"<span class='{css_class}'>{safe_sent}</span>")

    total = total_ai + total_human
    if total == 0:
        # Defensive: no sentence survived filtering, avoid division by zero.
        return "Please paste some text."

    # Overall verdict: report whichever side holds the majority share.
    if total_human >= total_ai:
        verdict = f"<br><br><b>Overall: {(total_human/total)*100:.2f}% Human</b>"
    else:
        verdict = f"<br><br><b>Overall: {(total_ai/total)*100:.2f}% AI</b>"
    return " ".join(highlighted) + verdict
| # Gradio interface with styling | |
# Assemble the Gradio UI: a plain multi-line textbox in, rendered HTML out.
_description = (
    "Detects AI-generated text using a ModernBERT ensemble. Sentences are highlighted:<br>"
    "<span style='color:#FF5733;font-weight:bold;'>AI-like</span> vs "
    "<span style='color:#4CAF50;font-weight:bold;'>Human-like</span>."
)

iface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=6, placeholder="Paste text here..."),
    outputs="html",
    title="AI Text Detector",
    description=_description,
)

iface.launch()