| import streamlit as st |
| import os |
| import torch |
| from transformers import ( |
| AutoTokenizer, AutoModelForSequenceClassification, |
| pipeline, BlipProcessor, BlipForConditionalGeneration |
| ) |
| from peft import PeftModel |
| from PIL import Image |
| import requests |
|
|
| |
# Mapping from classifier output index to human-readable category label.
# NOTE(review): index 3 is the explicit "Safe" class while index 8 is a
# generic lowercase "unsafe" bucket — presumably this mirrors the label
# order used when fine-tuning the DistilBERT-LoRA model; confirm against
# the training label map before changing.
id2label = {
    0: "Child Sexual Exploitation",
    1: "Elections",
    2: "Non-Violent Crimes",
    3: "Safe",
    4: "Sex-Related Crimes",
    5: "Suicide & Self-Harm",
    6: "Unknown S-Type",
    7: "Violent Crimes",
    8: "unsafe"
}
|
|
| |
@st.cache_resource
def load_caption_model():
    """Load and cache the BLIP image-captioning processor and model.

    Cached via ``st.cache_resource`` so the Hugging Face weights are
    downloaded/instantiated only once per Streamlit session.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return blip_processor, blip_model
|
|
def caption_image(img):
    """Generate a natural-language caption for a PIL image using BLIP.

    Args:
        img: image object accepted by the BLIP processor (e.g. PIL.Image).

    Returns:
        The decoded caption string, without special tokens.
    """
    blip_processor, blip_model = load_caption_model()
    batch = blip_processor(images=img, return_tensors="pt")
    generated_ids = blip_model.generate(pixel_values=batch["pixel_values"])
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
|
|
| |
@st.cache_resource
def load_toxic_classifier():
    """Load and cache the LoRA-adapted DistilBERT toxicity classifier.

    Returns:
        A callable that takes a text string and yields per-class scores:
        preferably a Hugging Face ``text-classification`` pipeline, or a
        manual fallback with a compatible call signature if pipeline
        construction rejects the PEFT-wrapped model.
    """
    model_dir = "NightPrince/peft-distilbert-toxic-classifier"
    base_model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=9
    )
    model = PeftModel.from_pretrained(base_model, model_dir)
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    try:
        # FIX: `top_k=None` returns scores for every class and replaces the
        # deprecated `return_all_scores=True` flag; the downstream
        # classify_toxicity() already unwraps either output shape.
        pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
        return pipe
    except Exception:
        # Best-effort fallback: run tokenizer + model directly, emulating the
        # pipeline's nested [[{label, score}, ...]] output format.
        def manual_pipe(text):
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                outputs = model(**inputs)
            scores = outputs.logits.softmax(dim=-1).squeeze().tolist()
            return [[{"label": str(i), "score": float(score)} for i, score in enumerate(scores)]]
        return manual_pipe
|
|
def classify_toxicity(text_input, caption):
    """Classify the combined text + image caption into a toxicity category.

    Args:
        text_input: raw user text.
        caption: BLIP-generated image caption (may be empty).

    Returns:
        Tuple of (top label name, top score, newline-joined per-class table);
        ("Unknown", 0.0, "No prediction") when the classifier output is
        not in a recognized shape.
    """
    classifier = load_toxic_classifier()
    combined = text_input + " [SEP] " + caption
    predictions = classifier(combined)

    # A batched pipeline wraps the per-class dicts in an extra list — unwrap it.
    if isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], list):
        predictions = predictions[0]
    if not isinstance(predictions, list) or len(predictions) == 0 or not isinstance(predictions[0], dict):
        return "Unknown", 0.0, "No prediction"

    def to_id(raw_label):
        # Labels arrive either as "LABEL_3" (pipeline) or "3" (manual fallback).
        tail = raw_label.split("_")[-1] if "_" in raw_label else raw_label
        return int(tail)

    best = max(predictions, key=lambda entry: entry.get('score', 0))
    final_label = id2label.get(to_id(best.get('label', '0')), "Unknown")
    top_score = best.get('score', 0.0)

    table_rows = []
    for entry in predictions:
        name = id2label.get(to_id(entry.get('label', '0')), 'Unknown')
        pct = round(float(entry.get('score', 0)) * 100, 2)
        table_rows.append(f"{name}: {pct}%")

    return final_label, top_score, "\n".join(table_rows)
|
|
| |
|
|
def llama_guard_filter(text):
    """Stage-1 hard safety gate via the OpenRouter Llama Guard 3 endpoint.

    Args:
        text: combined user text + image caption to screen.

    Returns:
        "safe" or "unsafe" based on the model verdict, or "api_error" when
        the API key is missing or the request/response handling fails.
    """
    # SECURITY FIX: the API key must come from the environment — a live
    # secret was previously hard-coded here (it should be revoked).
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        print("Llama Guard API Error: OPENROUTER_API_KEY is not set")
        return "api_error"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "meta-llama/llama-guard-3-8b",
        "max_tokens": 25,
        "messages": [
            {
                "role": "system",
                "content": (
                    "You are a strict safety classifier. Respond with only one word: either 'safe' or 'unsafe'. "
                    "Do not add punctuation, explanation, or formatting."
                )
            },
            {"role": "user", "content": text}
        ]
    }
    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers, json=data, timeout=10
        )
        content = response.json()["choices"][0]["message"]["content"].strip().lower()
        # BUG FIX: check "unsafe" before "safe" — the old `"safe" in content`
        # substring test also matched inside "unsafe", so every unsafe verdict
        # was reported as safe and the filter never blocked anything.
        if "unsafe" in content:
            return "unsafe"
        if "safe" in content:
            return "safe"
        # Unrecognized verdict: fail closed, matching the original else-branch.
        return "unsafe"
    except Exception as e:
        print("Llama Guard API Error:", e)
        return "api_error"
|
|
| |
# --- Streamlit page scaffold and input widgets ---
st.set_page_config(page_title="Toxic Moderation System", layout="centered")
st.title("🛡️ Dual-Stage Toxic Moderation")
st.markdown("Moderate text and images using **Llama Guard** + **DistilBERT-LoRA**.\n\n- Stage 1: Hard Safety Filter (Llama Guard)\n- Stage 2: Fine Toxic Classifier (LoRA DistilBERT)")


# User inputs: free-form text plus an optional image to caption and moderate.
text_input = st.text_area("✏️ Enter a text message", height=150)
uploaded_image = st.file_uploader("📷 Upload an image", type=["jpg", "jpeg", "png"])
|
|
# When an image is uploaded, display it and derive a text caption via BLIP so
# the downstream text-only moderation stages can screen image content too.
image_caption = ""
if uploaded_image:
    image = Image.open(uploaded_image)
    # NOTE(review): `use_column_width` is deprecated in recent Streamlit
    # releases in favor of `use_container_width` — confirm target version.
    st.image(image, caption="Uploaded Image", use_column_width=True)
    with st.spinner("🔍 Generating caption with BLIP..."):
        image_caption = caption_image(image)
    st.success(f"📝 Caption: `{image_caption}`")
|
|
|
|
|
|
# Two-stage moderation flow: Llama Guard gate first, then, only if the
# content passes, the fine-grained DistilBERT-LoRA category classifier.
if st.button("🚀 Run Moderation"):
    # Text and caption are joined with [SEP], mirroring classify_toxicity().
    full_text = text_input + " [SEP] " + image_caption
    with st.spinner("🛡️ Stage 1: Llama Guard..."):
        safety = llama_guard_filter(full_text)

    if safety == "unsafe":
        st.error("❌ Llama Guard flagged this content as **UNSAFE**.\nModeration stopped.")
    elif safety == "safe":
        st.success("✅ Safe by Llama Guard. Proceeding to classifier...")
        with st.spinner("🧠 Stage 2: DistilBERT Toxic Classifier..."):
            label, score, scores = classify_toxicity(text_input, image_caption)
        st.markdown(f"### 🔍 Prediction: `{label}` ({round(score*100, 2)}%)")
        st.text("📊 Class Probabilities:\n" + scores)
    else:
        # Any other value (e.g. "api_error") falls through to a caution banner.
        st.warning(f"Llama Guard API returned: {safety}. Proceed with caution.")
|
|
| |