# MGT-Detection / app.py
# Author: ziadmostafa
# Commit 640b4b2: "added app files"
import json
import random
from pathlib import Path
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Constants
# Inclusive word-count range the app is designed to accept for classification.
MIN_WORDS = 50
MAX_WORDS = 500
# Resolve the samples file relative to this script rather than the current
# working directory, so the app also starts when launched from elsewhere.
SAMPLE_JSON_PATH = Path(__file__).resolve().parent / 'samples.json'
# Model loading
def load_model(model_name):
    """Build a ``text-classification`` pipeline for *model_name*.

    Both the tokenizer and the sequence-classification model are obtained via
    ``from_pretrained(model_name)``. The pipeline truncates inputs to 512
    tokens and is configured with ``top_k=4`` so it reports the four
    highest-scoring labels.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        The configured ``transformers`` pipeline.
    """
    tokenizer_obj = AutoTokenizer.from_pretrained(model_name)
    seq_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    pipeline_settings = {
        'truncation': True,
        'max_length': 512,
        'top_k': 4,
    }
    return pipeline(
        'text-classification',
        model=seq_model,
        tokenizer=tokenizer_obj,
        **pipeline_settings,
    )
# Single shared classification pipeline, built once when this module runs.
classifier = load_model(model_name="ziadmostafa/MGT-Detection_deberta-base")
# Load sample essays.
# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so decoding does
# not depend on the platform's default locale encoding.
# NOTE(review): ``demo_essays`` is not referenced elsewhere in this file —
# confirm whether it is still needed.
with open(SAMPLE_JSON_PATH, 'r', encoding='utf-8') as f:
    demo_essays = json.load(f)
# Module-level slot for the current essay index; starts out unset.
# NOTE(review): not read or written anywhere else in this file.
current_essay_index = None

# Raw model label -> display name. Only these two labels are surfaced;
# process_result skips any other label the classifier reports.
TEXT_CLASS_MAPPING = dict(
    LABEL_0='Human-Written',
    LABEL_2='Machine-Generated',
)
def process_result(text):
    """Classify *text* and return the display name of the most likely class.

    Only labels present in ``TEXT_CLASS_MAPPING`` are considered. Scores for
    any other label reported by the classifier are ignored.

    Args:
        text: The text to classify.

    Returns:
        The ``TEXT_CLASS_MAPPING`` display name with the highest score.

    Raises:
        ValueError: If the classifier reported none of the mapped labels.
    """
    # The code expects ``classifier(text)[0]`` to be a list of
    # {'label': ..., 'score': ...} dicts. top_k is set in load_model.
    predictions = classifier(text)[0]
    scores = {
        TEXT_CLASS_MAPPING[pred['label']]: pred['score']
        for pred in predictions
        if pred['label'] in TEXT_CLASS_MAPPING
    }
    if not scores:
        raise ValueError(
            'Classifier returned none of the expected labels: '
            f'{sorted(TEXT_CLASS_MAPPING)}'
        )
    # Return only the label with the highest score.
    return max(scores, key=scores.get)
def update_result(name):
    """Return the classification for the submitted text, or ``""`` if there is none.

    Args:
        name: Contents of the input textbox. It is the sole input of the
            "Check Origin" click handler.

    Returns:
        ``""`` when the text is ``None``, empty or whitespace-only. Otherwise,
        the label returned by ``process_result``.
    """
    # Don't send missing or blank text to the model.
    if not name or not name.strip():
        return ""
    return process_result(name)
def active_button(input_text):
    """Return the "Check Origin" button, enabled only for texts of valid length.

    Args:
        input_text: Current contents of the input textbox. ``None`` is treated
            as empty text.

    Returns:
        A primary ``gr.Button`` labelled "Check Origin". It is interactive
        only when the text has ``MIN_WORDS`` to ``MAX_WORDS`` whitespace-
        separated words (inclusive).
    """
    # Use the module-level limits instead of repeating the literals 50 and 500.
    word_count = len((input_text or "").split())
    return gr.Button(
        "Check Origin",
        variant="primary",
        interactive=MIN_WORDS <= word_count <= MAX_WORDS,
    )
def clear_inputs():
    """Reset the form: empty the textbox and disable the "Check Origin" button.

    Returns:
        A ``(text, button)`` pair. It feeds the ``[input_text, check_button]``
        outputs of the Clear button's click handler.
    """
    disabled_button = gr.Button("Check Origin", variant="primary", interactive=False)
    return "", disabled_button
def count_words(text):
    """Return the word-counter caption for *text*.

    Words are whitespace-separated tokens (``str.split()``), and ``None`` is
    treated as empty text. With the current limits, empty text yields
    ``"0/500 words (Minimum 50 words)"``.
    """
    word_count = len((text or "").split())
    # Build the caption from the module-level limits rather than literals.
    return f'{word_count}/{MAX_WORDS} words (Minimum {MIN_WORDS} words)'
# Custom stylesheet passed to ``gr.Blocks``. It sets the page font and styles
# the ``.class-intro`` panel used by ``class_intro_html``.
css = """
body, .gradio-container {
    font-family: Arial, sans-serif;
}
.class-intro {
    padding: 15px;
    margin-bottom: 20px;
    border-radius: 5px;
}
.class-intro h2 {
    margin-top: 0;
}
.class-intro p {
    margin-bottom: 5px;
}
"""
# HTML panel that explains the two output classes. It is rendered via gr.HTML
# and styled by the ``.class-intro`` rules in the custom CSS.
class_intro_html = "\n".join([
    "",
    '<div class="class-intro">',
    "<h2>Text Classes</h2>",
    "<p><strong>Human-Written:</strong> Original text created by humans.</p>",
    "<p><strong>Machine-Generated:</strong> Text created by AI from basic prompts, without style instructions.</p>",
    "</div>",
    "",
])
# Build the Gradio interface, wire up its event handlers and start the server.
with gr.Blocks(css=css) as demo:
    # Centered page title (matching <center>...</center> tags).
    gr.Markdown("""<h1><center>Machine Generated Text Detection</center></h1>""")
    gr.HTML(class_intro_html)

    with gr.Row():
        input_text = gr.Textbox(
            placeholder="Paste your text here...",
            label="Text",
            lines=10,
            max_lines=15,
        )
    with gr.Row():
        # Seed the counter with the caption count_words produces for empty text.
        wc = gr.Markdown(count_words(""))
    with gr.Row():
        # Starts disabled; active_button enables it once the word count is valid.
        check_button = gr.Button("Check Origin", variant="primary", interactive=False)
        clear_button = gr.ClearButton([input_text], variant="stop")

    out = gr.Label(label='Result')
    # Clear the result label together with the textbox.
    clear_button.add(out)

    check_button.click(fn=update_result, inputs=[input_text], outputs=out)
    # Refresh the word counter whenever the text changes.
    input_text.change(count_words, input_text, wc, show_progress=False)
    # Re-evaluate the button state as the user types.
    input_text.input(
        active_button,
        [input_text],
        [check_button],
    )
    # `.input` only reacts to user edits, so disable the button explicitly
    # when the form is cleared.
    clear_button.click(
        clear_inputs,
        inputs=[],
        outputs=[input_text, check_button],
    )

demo.launch(share=False)