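"""AI Content Moderator: a Gradio Space that runs the gemma-3-270m-it model
on CPU to classify user text as safe ('s') or unsafe ('u')."""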
import os
import time
import threading
import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
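# --- Configuration and CPU tuning --------------------------------------------
# The Space runs on CPU, so pin BLAS/OpenMP threads to the available cores and
# keep a single interop thread to avoid oversubscription during decoding.
# HF_HUB_ENABLE_HF_TRANSFER speeds up downloads but assumes the hf_transfer
# package is installed; drop it if that dependency is unavailable.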
MODEL_REPO = "daniel-dona/gemma-3-270m-it"
LOCAL_DIR = os.path.join(os.getcwd(), "local_model")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("OMP_NUM_THREADS", str(os.cpu_count() or 1))
os.environ.setdefault("MKL_NUM_THREADS", os.environ["OMP_NUM_THREADS"])
os.environ.setdefault("OMP_PROC_BIND", "TRUE")
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
torch.set_num_interop_threads(1)
torch.set_float32_matmul_precision("high")
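# --- Model download -----------------------------------------------------------
# Fetch the snapshot into a local directory, retrying with exponential backoff
# (3s, 6s, 12s, ...) since Hub downloads can fail transiently on Space startup.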
def ensure_local_model(repo_id: str, local_dir: str, tries: int = 3, sleep_s: float = 3.0) -> str:
    os.makedirs(local_dir, exist_ok=True)
    for i in range(tries):
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                local_dir_use_symlinks=False,  # deprecated on recent huggingface_hub, but still accepted
                resume_download=True,
                allow_patterns=["*.json", "*.model", "*.safetensors", "*.bin", "*.txt", "*.py"]
            )
            return local_dir
        except Exception:
            if i == tries - 1:
                raise
            time.sleep(sleep_s * (2 ** i))
    return local_dir
model_path = ensure_local_model(MODEL_REPO, LOCAL_DIR)
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
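# --- Chat template ------------------------------------------------------------
# Fallback Gemma-style template, used only if the tokenizer ships without one.
# It renders user/assistant turns as <start_of_turn>...<end_of_turn> blocks and,
# when add_generation_prompt is set, opens a model turn for the reply, e.g.:
#   <start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\n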
gemma_chat_template_simplified = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '<start_of_turn>user\\n' + message['content'] | trim + '<end_of_turn>\\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '<start_of_turn>model\\n' + message['content'] | trim + '<end_of_turn>\\n' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<start_of_turn>model\\n' }}"
    "{% endif %}"
)
if tokenizer.chat_template is None:
    tokenizer.chat_template = gemma_chat_template_simplified
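# --- Model load ---------------------------------------------------------------
# Load in float32 on CPU; device_map=None keeps all weights on the default device.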
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    local_files_only=True,
    torch_dtype=torch.float32,
    device_map=None
)
model.eval()
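# --- Moderation prompt ----------------------------------------------------------
# The classifier instruction is folded into the user turn because the simplified
# template above has no system role.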
MODERATION_SYSTEM_PROMPT = (
    "You are a multilingual content moderation classifier. "
    "You MUST respond with exactly one lowercase letter: 's' for safe, 'u' for unsafe. "
    "No explanations, no punctuation, no extra words. "
    "If the message contains hate speech, harassment, sexual content involving minors, "
    "extreme violence, self-harm encouragement, or other unsafe material, respond 'u'. "
    "Otherwise respond 's'."
)
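# Build the rendered prompt, then shrink the message until it fits the token
# budget: chop 50 characters off the end and re-render, a crude but
# dependency-free form of truncation.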
def build_prompt(message, max_ctx_tokens=128):
    full_user_message = f"{MODERATION_SYSTEM_PROMPT}\n\nUser input: '{message}'"
    messages = [{"role": "user", "content": full_user_message}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    while len(tokenizer(text, add_special_tokens=False).input_ids) > max_ctx_tokens and len(full_user_message) > 100:
        full_user_message = full_user_message[:-50]
        messages[0]["content"] = full_user_message
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    return text
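# A 270M-parameter model does not reliably emit a single letter, so the raw
# output is post-processed onto the strict label set. Examples:
#   "u" -> "u"    "unsafe" -> "u"
#   "s" -> "s"    anything ambiguous defaults to "s"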
def enforce_s_u(text: str) -> str:
    text_lower = text.strip().lower()
    if "u" in text_lower and "s" not in text_lower:
        return "u"
    if "unsafe" in text_lower:
        return "u"
    return "s"
def format_classification_result(classification, tokens_per_second, processing_time):
    if classification == "s":
        status_emoji = "✅"
        status_text = "SAFE"
        status_color = "#22c55e"
        description = "Content appears to be safe and appropriate."
    else:
        status_emoji = "🚫"
        status_text = "UNSAFE"
        status_color = "#ef4444"
        description = "Content may contain inappropriate or harmful material."
    result_html = f"""
    <div style="text-align: center; padding: 20px; border-radius: 12px;
                background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
                border: 2px solid {status_color}; margin: 10px 0;">
        <div style="font-size: 48px; margin-bottom: 10px;">{status_emoji}</div>
        <div style="font-size: 24px; font-weight: bold; color: {status_color}; margin-bottom: 8px;">
            {status_text}
        </div>
        <div style="font-size: 16px; color: #64748b; margin-bottom: 15px;">
            {description}
        </div>
        <div style="display: flex; justify-content: center; gap: 20px; font-size: 14px; color: #475569;">
            <span>⚡ {tokens_per_second:.1f} tok/s</span>
            <span>⏱️ {processing_time:.2f}s</span>
        </div>
    </div>
    """
    return result_html
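# --- Classification -------------------------------------------------------------
# generate() runs in a background thread while the main thread drains the
# TextIteratorStreamer, counts chunks (~1 token each) for a rough tok/s figure,
# and drives the progress bar.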
def classify_text_stream(message, max_tokens, temperature, top_p, progress=gr.Progress()):
    if not message.strip():
        return format_classification_result("s", 0, 0)
    progress(0, desc="Preparing classification...")
    text = build_prompt(message)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    do_sample = bool(temperature and temperature > 0.0)
    gen_kwargs = dict(
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        top_p=top_p if do_sample else None,  # avoid sampler warnings in greedy mode
        temperature=temperature if do_sample else None,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )
    try:
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    except TypeError:
        # Older transformers versions do not accept skip_prompt.
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    def _generate():
        # inference_mode is thread-local, so it must wrap the generate call in
        # the worker thread; wrapping thread.start() would have no effect.
        with torch.inference_mode():
            model.generate(**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}, streamer=streamer)

    thread = threading.Thread(target=_generate)
    partial_text = ""
    token_count = 0
    start_time = None
    progress(0.3, desc="Processing content...")
    thread.start()
    try:
        for chunk in streamer:
            if start_time is None:
                start_time = time.time()
            partial_text += chunk
            token_count += 1  # approximate: one streamer chunk is roughly one token
            progress(min(0.9, 0.3 + (token_count / max_tokens) * 0.6), desc="Analyzing...")
    finally:
        thread.join()
    final_label = enforce_s_u(partial_text)
    end_time = time.time()
    if start_time is None:
        start_time = end_time  # no tokens were streamed; avoid subtracting None
    duration = max(1e-6, end_time - start_time)
    tps = token_count / duration
    progress(1.0, desc="Complete!")
    return format_classification_result(final_label, tps, duration)
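# --- UI ---------------------------------------------------------------------------
# Custom styling and the Blocks layout for the demo.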
custom_css = """
.main-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}
.header-section {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 2rem;
    border-radius: 16px;
    margin-bottom: 2rem;
    color: white;
    text-align: center;
}
.classification-panel {
    background: white;
    border-radius: 16px;
    padding: 2rem;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
    border: 1px solid #e2e8f0;
}
.example-card {
    transition: transform 0.2s ease;
}
.example-card:hover {
    transform: translateY(-2px);
}
.gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.input-section {
    background: #f8fafc;
    border-radius: 12px;
    padding: 1.5rem;
    border: 1px solid #e2e8f0;
}
"""
with gr.Blocks(css=custom_css, title="AI Content Moderator", theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="main-container"):
        gr.HTML("""
        <div class="header-section">
            <h1 style="font-size: 2.5rem; margin-bottom: 0.5rem; font-weight: 700;">
                🛡️ AI Content Moderator
            </h1>
            <p style="font-size: 1.2rem; opacity: 0.9; margin: 0;">
                Advanced multilingual content classification powered by AI
            </p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Group(elem_classes="input-section"):
                    gr.Markdown("### 📝 Content Analysis")
                    text_input = gr.Textbox(
                        label="Text to Analyze",
                        placeholder="Enter any text in any language for content moderation analysis...",
                        lines=6,
                        max_lines=10,
                        show_label=False
                    )
                    with gr.Row():
                        classify_btn = gr.Button(
                            "🔍 Analyze Content",
                            variant="primary",
                            size="lg",
                            scale=2
                        )
                        clear_btn = gr.Button(
                            "🗑️ Clear",
                            variant="secondary",
                            size="lg",
                            scale=1
                        )
            with gr.Column(scale=2):
                with gr.Group(elem_classes="classification-panel"):
                    gr.Markdown("### 📊 Classification Result")
                    result_display = gr.HTML(
                        value=format_classification_result("s", 0, 0),
                        label="Result"
                    )
        with gr.Accordion("⚙️ Advanced Configuration", open=False):
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=1, maximum=10, value=3, step=1,
                    label="Max Tokens",
                    info="Maximum number of tokens to generate"
                )
                temp_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
                    label="Temperature",
                    info="Controls randomness in generation"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top-p",
                    info="Nucleus sampling parameter"
                )
        gr.Markdown("### 💡 Try These Examples")
        example_data = [
            ["Hello, how are you today? I hope you're having a wonderful time!"],
            ["I hate you and I will find you and hurt you badly."],
            ["C'est une belle journée pour apprendre la programmation et l'intelligence artificielle."],
            ["I can't take this anymore. I want to end everything and disappear forever."],
            ["¡Hola! Me encanta aprender nuevos idiomas y conocer diferentes culturas."],
            ["You're absolutely worthless and nobody will ever love someone like you."]
        ]
        examples = gr.Examples(
            examples=example_data,
            inputs=text_input,
            examples_per_page=6
        )
        gr.Markdown("""
        ---
        <div style="text-align: center; padding: 1rem; color: #64748b; font-size: 0.9rem;">
            <p><strong>🌍 Multilingual Support:</strong> English, Spanish, French, German, and many more languages</p>
            <p><strong>🚀 Real-time Analysis:</strong> Fast content classification with detailed feedback</p>
            <p><strong>🔒 Privacy First:</strong> All processing happens locally on your machine</p>
        </div>
        """)
    classify_btn.click(
        fn=classify_text_stream,
        inputs=[text_input, max_tokens_slider, temp_slider, top_p_slider],
        outputs=result_display,
        show_progress="full"  # Gradio 4+ expects a string here, not a bool
    )
    clear_btn.click(
        fn=lambda: ("", format_classification_result("s", 0, 0)),
        outputs=[text_input, result_display]
    )
if __name__ == "__main__":
    # One-token warm-up so the first real request does not pay cache-setup cost.
    with torch.inference_mode():
        _ = model.generate(
            **tokenizer(["Hi"], return_tensors="pt").to(model.device),
            max_new_tokens=1, do_sample=False, use_cache=True
        )
    print("🚀 Starting AI Content Moderator...")
    demo.queue(max_size=64).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )