Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| import torch.hub | |
| import re | |
| import os | |
| import time | |
# --- Set Page Config First ---
# The page settings are gathered in a mapping and unpacked into the call; this
# statement stays ahead of every other st.* call in the script.
_PAGE_SETTINGS = {
    "page_title": "AI Text Detector",
    "layout": "centered",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
# --- Improved CSS for a cleaner UI ---
# The stylesheet lives in a named constant so the injection call below stays
# short. The markup is passed to Streamlit verbatim; unsafe_allow_html is
# needed because the payload is a raw <style> element.
_CUSTOM_CSS = """
<style>
/* Modern clean font for the entire app */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
html, body, [class*="css"] {
font-family: 'Inter', sans-serif;
}
/* Header styling */
h1 {
font-weight: 700;
color: #1E3A8A;
padding-bottom: 1rem;
border-bottom: 2px solid #E5E7EB;
margin-bottom: 2rem;
}
/* Text area styling */
.stTextArea textarea {
border: 1px solid #D1D5DB;
border-radius: 8px;
font-size: 16px;
padding: 12px;
background-color: #F9FAFB;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out;
}
.stTextArea textarea:focus {
border-color: #3B82F6;
box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.3);
outline: none;
}
/* Button styling */
.stButton button {
border-radius: 8px;
font-weight: 600;
padding: 10px 16px;
background-color: #2563EB;
color: white;
border: none;
width: 100%;
transition: background-color 0.2s ease;
}
.stButton button:hover {
background-color: #1D4ED8;
}
/* Result box styling */
.result-box {
border-radius: 8px;
padding: 20px;
margin-top: 24px;
text-align: center;
background-color: white;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1), 0 1px 2px rgba(0, 0, 0, 0.06);
border: 1px solid #E5E7EB;
}
/* Result highlights */
.highlight-human {
color: #059669;
font-weight: 600;
background: rgba(5, 150, 105, 0.1);
padding: 4px 10px;
border-radius: 8px;
display: inline-block;
}
.highlight-ai {
color: #DC2626;
font-weight: 600;
background: rgba(220, 38, 38, 0.1);
padding: 4px 10px;
border-radius: 8px;
display: inline-block;
}
/* Footer styling */
.footer {
text-align: center;
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid #E5E7EB;
color: #6B7280;
font-size: 14px;
}
/* Progress bar styling */
.stProgress > div > div {
background-color: #2563EB;
}
/* General spacing */
.block-container {
padding-top: 2rem;
padding-bottom: 2rem;
}
</style>
"""

st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# --- Configuration ---
# Fine-tuned checkpoints: one local weights file and two remote ones.
MODEL1_PATH = "modernbert.bin"
MODEL2_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
MODEL3_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"

# Base architecture that every checkpoint is loaded into.
BASE_MODEL = "answerdotai/ModernBERT-base"

# Size of the classification head, and the class index that stands for "human".
NUM_LABELS = 41
HUMAN_LABEL_INDEX = 24

# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
| # --- Model Loading Functions --- | |
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_name):
    """Return the Hugging Face tokenizer for *model_name*.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is kept in Streamlit's resource cache instead of being rebuilt each
    time. An exception raised by ``from_pretrained`` propagates to the caller
    and is not cached.

    Args:
        model_name: Model identifier or path accepted by
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer instance created by ``AutoTokenizer.from_pretrained``.
    """
    # Imported lazily so the heavy transformers import happens on first use.
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(model_name)
@st.cache_resource(show_spinner=False)
def _load_model_cached(model_path_or_url, base_model, num_labels, is_url, _device):
    """Build the classifier and load its fine-tuned weights, raising on failure.

    Only successful loads end up in Streamlit's resource cache; a raised
    exception is not cached, so a later rerun can retry. ``_device`` starts
    with an underscore so that ``st.cache_resource`` leaves it out of the
    cache key.
    """
    from transformers import AutoModelForSequenceClassification

    # Check the local file first so a missing checkpoint fails fast, before
    # the base architecture is instantiated.
    if not is_url and not os.path.exists(model_path_or_url):
        raise FileNotFoundError(f"Model weights not found: {model_path_or_url}")

    # Load base model architecture
    model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=num_labels)

    # Load weights
    # NOTE(review): both loaders unpickle the checkpoint without the
    # weights_only restriction, which can execute arbitrary code from an
    # untrusted file. Consider weights_only=True once the checkpoints are
    # confirmed to be plain tensor state dicts.
    if is_url:
        state_dict = torch.hub.load_state_dict_from_url(model_path_or_url, map_location=_device, progress=False)
    else:
        state_dict = torch.load(model_path_or_url, map_location=_device, weights_only=False)

    model.load_state_dict(state_dict)
    model.to(_device).eval()
    return model


def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    """Load a fine-tuned sequence-classification model, or return ``None``.

    The actual work is done by a cached helper, so repeated Streamlit reruns
    reuse the model that is already in memory instead of downloading and
    rebuilding it.

    Args:
        model_path_or_url: Local path to a state-dict file, or a URL when
            ``is_url`` is true.
        base_model: Identifier of the base architecture passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        num_labels: Number of output classes of the classification head.
        is_url: Whether ``model_path_or_url`` is fetched with
            ``torch.hub.load_state_dict_from_url`` rather than ``torch.load``.
        _device: Device the weights are mapped to and the model is moved to.
            It is not part of the cache key.

    Returns:
        The model in eval mode on ``_device``, or ``None`` when the weights
        file is missing or any loading step fails. The failure is logged
        with its traceback rather than silently discarded.
    """
    import logging

    try:
        return _load_model_cached(model_path_or_url, base_model, num_labels, is_url, _device)
    except Exception:
        # Best-effort contract kept for callers: they test the result for
        # None, so report the cause here and hand back None.
        logging.getLogger(__name__).exception("Could not load model from %s", model_path_or_url)
        return None
| # --- Text Processing Functions --- | |
def clean_text(text):
    """Normalise whitespace and line breaks in *text* before tokenisation.

    Non-string input yields an empty string. Otherwise line endings are
    unified to ``\\n`` and the regex passes below run in order, after which
    the result is stripped of leading and trailing whitespace.
    """
    if not isinstance(text, str):
        return ""

    normalized = text.replace("\r\n", "\n").replace("\r", "\n")

    # (pattern, replacement) pairs, applied strictly in this order.
    passes = (
        (r"\n\s*\n+", "\n\n"),                # blank-line runs -> one paragraph break
        (r"[ \t]+", " "),                     # runs of spaces/tabs -> one space
        (r"(\w+)-\s*\n\s*(\w+)", r"\1\2"),    # re-join words hyphenated across a line break
        (r"(?<!\n)\n(?!\n)", " "),            # lone newline -> space
    )
    for pattern, replacement in passes:
        normalized = re.sub(pattern, replacement, normalized)

    return normalized.strip()
def classify_text(text, tokenizer, model_1, model_2, model_3, device, label_mapping, human_label_index):
    """Classify *text* as human-written or AI-generated with three models.

    The cleaned text is tokenised once, run through each model, and the three
    softmax distributions are averaged. The probability at
    ``human_label_index`` is compared with the summed probability of all
    other classes.

    Returns:
        * ``None`` when the cleaned text is empty.
        * ``{"error": True, "message": ...}`` when a model or the tokenizer is
          falsy, the human index is out of range, or any step raises.
        * ``{"is_human": bool, "probability": float, "model": str}`` otherwise.
          ``probability`` is a percentage, and ``model`` is ``"Human"`` or the
          name of the highest-scoring non-human class from ``label_mapping``.
    """
    if not (model_1 and model_2 and model_3 and tokenizer):
        return {"error": True, "message": "Models failed to load properly."}

    prepared = clean_text(text)
    if not prepared:
        return None

    try:
        encoded = tokenizer(
            prepared,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=tokenizer.model_max_length,
        ).to(device)

        with torch.no_grad():
            per_model_logits = [
                member(**encoded).logits for member in (model_1, model_2, model_3)
            ]

        first, second, third = (torch.softmax(logits, dim=1) for logits in per_model_logits)
        # Take the first (only) row of the averaged batch and move it to the CPU.
        probabilities = ((first + second + third) / 3)[0].cpu()

        if human_label_index < 0 or human_label_index >= len(probabilities):
            return {"error": True, "message": "Configuration error."}

        human_prob = probabilities[human_label_index].item() * 100

        # Everything that is not the human class counts towards the AI total.
        non_human = torch.ones_like(probabilities, dtype=torch.bool)
        non_human[human_label_index] = False
        ai_total_prob = probabilities[non_human].sum().item() * 100

        # Knock the human class out of the running before picking the top AI class.
        ai_only = probabilities.clone()
        ai_only[human_label_index] = -float('inf')
        top_ai_index = torch.argmax(ai_only).item()
        top_ai_name = label_mapping.get(top_ai_index, f"Unknown AI (Index {top_ai_index})")

        if human_prob >= ai_total_prob:
            return {"is_human": True, "probability": human_prob, "model": "Human"}
        return {"is_human": False, "probability": ai_total_prob, "model": top_ai_name}
    except Exception as e:
        return {"error": True, "message": f"Analysis failed: {str(e)}"}
# --- Label Mapping ---
# Class index -> source name. The position in the tuple is the class index,
# so the order must not change; index 24 is the 'human' class.
LABEL_MAPPING = dict(enumerate((
    '13B',
    '30B',
    '65B',
    '7B',
    'GLM130B',
    'bloom_7b',
    'bloomz',
    'cohere',
    'davinci',
    'dolly',
    'dolly-v2-12b',
    'flan_t5_base',
    'flan_t5_large',
    'flan_t5_small',
    'flan_t5_xl',
    'flan_t5_xxl',
    'gemma-7b-it',
    'gemma2-9b-it',
    'gpt-3.5-turbo',
    'gpt-35',
    'gpt4',
    'gpt4o',
    'gpt_j',
    'gpt_neox',
    'human',
    'llama3-70b',
    'llama3-8b',
    'mixtral-8x7b',
    'opt_1.3b',
    'opt_125m',
    'opt_13b',
    'opt_2.7b',
    'opt_30b',
    'opt_350m',
    'opt_6.7b',
    'opt_iml_30b',
    'opt_iml_max_1.3b',
    't0_11b',
    't0_3b',
    'text-davinci-002',
    'text-davinci-003',
)))
# --- Main UI ---
st.title("AI Text Detector")

# Initialization with a progress bar
with st.spinner(""):
    # Keep handles to both start-up widgets so they can be removed afterwards.
    progress_bar = st.progress(0)
    status_placeholder = st.empty()
    status_placeholder.info("Initializing AI detection models...")

    # Step 1: Load tokenizer
    progress_bar.progress(20)
    time.sleep(0.5)  # Small delay for visual feedback
    TOKENIZER = load_tokenizer(BASE_MODEL)

    # Step 2: Load first model
    progress_bar.progress(40)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE)

    # Step 3: Load second model
    progress_bar.progress(60)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

    # Step 4: Load third model
    progress_bar.progress(80)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

    # Complete initialization
    progress_bar.progress(100)
    time.sleep(0.5)  # Small delay for visual feedback

# Clear the initialization messages. A bare st.empty() call only adds a new,
# empty slot and leaves earlier elements in place, so each widget is cleared
# through its own handle.
progress_bar.empty()
status_placeholder.empty()

# Check if models loaded successfully. load_model signals failure with None,
# so test for None explicitly instead of relying on object truthiness.
if any(component is None for component in (TOKENIZER, MODEL_1, MODEL_2, MODEL_3)):
    st.error("Failed to initialize one or more AI detection models. Please try refreshing the page.")
    st.stop()
# Input area
input_text = st.text_area(
    label="Enter text to analyze:",
    placeholder="Type or paste your content here for AI detection analysis...",
    height=200,
    key="text_input",
)

# Analyze button and output
analyze_button = st.button("Analyze Text", key="analyze_button")
result_placeholder = st.empty()

if analyze_button:
    # None means "nothing to analyse": it is what classify_text returns for
    # text that is empty after cleaning, and it is kept when the box was left
    # blank. Both cases show the same warning.
    classification_result = None
    if input_text and input_text.strip():
        with st.spinner('Analyzing text...'):
            classification_result = classify_text(
                input_text,
                TOKENIZER,
                MODEL_1,
                MODEL_2,
                MODEL_3,
                DEVICE,
                LABEL_MAPPING,
                HUMAN_LABEL_INDEX,
            )

    # Display result
    if classification_result is None:
        result_placeholder.warning("Please enter some text to analyze.")
    elif classification_result.get("error"):
        error_message = classification_result.get("message", "An unknown error occurred during analysis.")
        result_placeholder.error(f"Analysis Error: {error_message}")
    else:
        prob = classification_result['probability']
        if classification_result["is_human"]:
            verdict_html = (
                f"<b>The text is</b> <span class='highlight-human'><b>{prob:.2f}%</b> likely <b>Human written</b>.</span>"
            )
        else:  # AI generated
            model_name = classification_result['model']
            verdict_html = (
                f"<b>The text is</b> <span class='highlight-ai'><b>{prob:.2f}%</b> likely <b>AI generated</b>.</span><br><br>"
                f"<b>Most Likely AI Model: {model_name}</b>"
            )
        # Wrap the verdict in the styled result card.
        result_html = "<div class='result-box'>" + verdict_html + "</div>"
        result_placeholder.markdown(result_html, unsafe_allow_html=True)

# Footer
st.markdown("<div class='footer'>Developed by Eeman Majumder</div>", unsafe_allow_html=True)