# Dual-Stage-Toxic-Moderation / app_streamlit.py
# Author: NightPrince — "Update app_streamlit.py" (commit 4179758, verified)
import streamlit as st
import os
import torch
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
pipeline, BlipProcessor, BlipForConditionalGeneration
)
from peft import PeftModel
from PIL import Image
import requests
# 1️⃣ Setup label mapping
# Index -> human-readable category name for the 9-way classifier head.
# Order matters: position i corresponds to logit/class id i.
id2label = dict(enumerate([
    "Child Sexual Exploitation",
    "Elections",
    "Non-Violent Crimes",
    "Safe",
    "Sex-Related Crimes",
    "Suicide & Self-Harm",
    "Unknown S-Type",
    "Violent Crimes",
    "unsafe",
]))
# 2️⃣ Load BLIP captioning model
@st.cache_resource
def load_caption_model():
    """Load and cache the BLIP processor/model pair used for image captioning."""
    blip_repo = "Salesforce/blip-image-captioning-base"
    blip_processor = BlipProcessor.from_pretrained(blip_repo)
    blip_model = BlipForConditionalGeneration.from_pretrained(blip_repo)
    return blip_processor, blip_model
def caption_image(img):
    """Generate a text caption for a PIL image using the cached BLIP model."""
    blip_processor, blip_model = load_caption_model()
    # BLIP's generate() consumes pixel_values (image tensors), not input_ids.
    encoded = blip_processor(images=img, return_tensors="pt")
    generated_ids = blip_model.generate(pixel_values=encoded["pixel_values"])
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
# 3️⃣ Load your DistilBERT+LoRA classifier
@st.cache_resource
def load_toxic_classifier():
    """Build and cache a text-classification callable.

    Loads the DistilBERT base model, applies the fine-tuned LoRA adapter from
    the Hub, then tries to wrap everything in a HF pipeline. If pipeline
    construction fails, falls back to a manual inference closure that mimics
    the pipeline's nested-list output format.
    """
    adapter_repo = "NightPrince/peft-distilbert-toxic-classifier"
    base = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=9
    )
    model = PeftModel.from_pretrained(base, adapter_repo)
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    try:
        # Preferred path: the standard HF pipeline with all class scores.
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            return_all_scores=True,
        )
    except Exception:
        # Fallback: manual tokenize -> forward -> softmax, same output shape.
        def manual_pipe(text):
            encoded = tokenizer(
                text, return_tensors="pt", truncation=True, padding=True
            )
            with torch.no_grad():
                logits = model(**encoded).logits
            probs = logits.softmax(dim=-1).squeeze().tolist()
            return [[
                {"label": str(idx), "score": float(p)}
                for idx, p in enumerate(probs)
            ]]
        return manual_pipe
def _label_to_name(raw_label):
    """Translate a classifier label string into a human-readable category name.

    Accepts "3" (manual fallback), "LABEL_3" (HF pipeline), or an already
    readable class name. Previously a non-numeric label crashed with
    ValueError on int(); now we pass through known readable names and
    return "Unknown" otherwise.
    """
    try:
        label_id = int(raw_label.split("_")[-1]) if "_" in raw_label else int(raw_label)
    except (TypeError, ValueError):
        # Model config may already store readable names; accept only known ones.
        return raw_label if raw_label in id2label.values() else "Unknown"
    return id2label.get(label_id, "Unknown")


def classify_toxicity(text_input, caption):
    """Run the DistilBERT+LoRA classifier on text combined with an image caption.

    Parameters:
        text_input: the user-entered message.
        caption: BLIP caption for the uploaded image ("" when no image).

    Returns:
        (final_label, top_score, scores_table) — the winning category name,
        its probability, and a newline-joined "Name: pct%" breakdown of every
        class; ("Unknown", 0.0, "No prediction") when output is malformed.
    """
    pipe = load_toxic_classifier()
    full_input = text_input + " [SEP] " + caption
    preds = pipe(full_input)
    # Both the HF pipeline (return_all_scores=True) and the manual fallback
    # wrap the per-class dicts in an outer list; unwrap one level.
    if isinstance(preds, list) and len(preds) > 0 and isinstance(preds[0], list):
        preds = preds[0]
    if not isinstance(preds, list) or len(preds) == 0 or not isinstance(preds[0], dict):
        return "Unknown", 0.0, "No prediction"
    preds_sorted = sorted(preds, key=lambda x: x.get('score', 0), reverse=True)
    top = preds_sorted[0]
    final_label = _label_to_name(top.get('label', '0'))
    top_score = top.get('score', 0.0)
    scores_table = "\n".join(
        f"{_label_to_name(item.get('label', '0'))}: "
        f"{round(float(item.get('score', 0)) * 100, 2)}%"
        for item in preds
    )
    return final_label, top_score, scores_table
# ✅ Llama Guard Filter Function
def llama_guard_filter(text):
    """Stage-1 hard safety filter via OpenRouter's Llama Guard 3 8B.

    Parameters:
        text: the combined user text + image caption to screen.

    Returns:
        "safe", "unsafe", or "api_error" when the request/parse fails.

    SECURITY FIX: the API key was previously hard-coded in source (a leaked
    secret); it is now read from the OPENROUTER_API_KEY environment variable.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY", "")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "meta-llama/llama-guard-3-8b",
        "max_tokens": 25,
        "messages": [
            {
                "role": "system",
                "content": (
                    "You are a strict safety classifier. Respond with only one word: either 'safe' or 'unsafe'. "
                    "Do not add punctuation, explanation, or formatting."
                )
            },
            {"role": "user", "content": text}
        ]
    }
    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers, json=data, timeout=10
        )
        content = response.json()["choices"][0]["message"]["content"].strip().lower()
        # BUG FIX: check "unsafe" before "safe". The substring "safe" occurs
        # inside "unsafe", so the old `if "safe" in content` test mapped every
        # unsafe verdict to "safe" and effectively disabled this filter.
        if "unsafe" in content:
            return "unsafe"
        if "safe" in content:
            return "safe"
        # Unexpected wording from the model: fail closed (matches the old
        # else-branch, which also returned "unsafe").
        return "unsafe"
    except Exception as e:
        print("Llama Guard API Error:", e)
        return "api_error"
# 🌐 Streamlit UI
# Top-level script: renders inputs, then runs the two-stage moderation
# pipeline (Llama Guard hard filter -> DistilBERT+LoRA fine classifier).
st.set_page_config(page_title="Toxic Moderation System", layout="centered")
st.title("🛡️ Dual-Stage Toxic Moderation")
st.markdown("Moderate text and images using **Llama Guard** + **DistilBERT-LoRA**.\n\n- Stage 1: Hard Safety Filter (Llama Guard)\n- Stage 2: Fine Toxic Classifier (LoRA DistilBERT)")

# Inputs: free text plus an optional image. The image is captioned by BLIP so
# it can be moderated as text alongside the message.
text_input = st.text_area("✏️ Enter a text message", height=150)
uploaded_image = st.file_uploader("📷 Upload an image", type=["jpg", "jpeg", "png"])
image_caption = ""  # stays "" when no image is uploaded
if uploaded_image:
    image = Image.open(uploaded_image)
    # NOTE(review): use_column_width is deprecated in recent Streamlit
    # releases in favor of use_container_width — confirm the pinned version.
    st.image(image, caption="Uploaded Image", use_column_width=True)
    with st.spinner("🔍 Generating caption with BLIP..."):
        image_caption = caption_image(image)
    st.success(f"📝 Caption: `{image_caption}`")

if st.button("🚀 Run Moderation"):
    # Stage 1: hard safety filter on the combined text + caption.
    full_text = text_input + " [SEP] " + image_caption
    with st.spinner("🛡️ Stage 1: Llama Guard..."):
        safety = llama_guard_filter(full_text)
    if safety == "unsafe":
        st.error("❌ Llama Guard flagged this content as **UNSAFE**.\nModeration stopped.")
    elif safety == "safe":
        st.success("✅ Safe by Llama Guard. Proceeding to classifier...")
        # Stage 2: fine-grained 9-way classification, only for guard-safe input.
        with st.spinner("🧠 Stage 2: DistilBERT Toxic Classifier..."):
            label, score, scores = classify_toxicity(text_input, image_caption)
        st.markdown(f"### 🔍 Prediction: `{label}` ({round(score*100, 2)}%)")
        st.text("📊 Class Probabilities:\n" + scores)
    else:
        # llama_guard_filter returned "api_error" (or something unexpected).
        st.warning(f"Llama Guard API returned: {safety}. Proceed with caution.")
#All Thanks to Cellula for the opportunity