NeuroShieldApp / app.py
PradAgrawal's picture
Upload 4 files
68635d7 verified
import streamlit as st
import os
import time
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from groq import Groq
# --------------------------------------------------------------------------
# Configuration & Model Loading (Cached for efficiency)
# --------------------------------------------------------------------------
CLASSIFIER_MODEL_NAME = "unitary/toxic-bert"
LLM_MODEL_GROQ = "llama3-8b-8192" # Or mixtral-8x7b-32768
st.set_page_config(page_title="NeuroShield PoC", layout="wide")
# Use Streamlit's caching for expensive operations like model loading
@st.cache_resource
def load_classifier_model():
"""Loads the classifier model and tokenizer."""
print("Loading classifier model and tokenizer...")
try:
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME)
# Determine device (use CPU on free HF Spaces usually, unless GPU assigned)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
print(f"Classifier model loaded on {device}.")
# Get labels from model config
model_labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
return tokenizer, model, device, model_labels
except Exception as e:
st.error(f"Error loading classifier model: {e}")
print(f"Error loading classifier model: {e}")
return None, None, None, []
@st.cache_resource
def initialize_groq_client():
"""Initializes the Groq client using API key from secrets."""
print("Initializing Groq client...")
try:
# Use st.secrets for Streamlit Community Cloud or os.environ for HF Spaces
groq_api_key = os.environ.get('GROQ_API_KEY')
if not groq_api_key:
# Fallback for local testing if using secrets.toml
try:
groq_api_key = st.secrets["GROQ_API_KEY"]
except Exception:
st.warning("GROQ_API_KEY not found in environment variables or st.secrets.")
return None
if not groq_api_key:
st.warning("Groq API Key not configured.")
return None
else:
client = Groq(api_key=groq_api_key)
print("Groq client initialized.")
return client
except Exception as e:
st.error(f"Error initializing Groq client: {e}")
print(f"Error initializing Groq client: {e}")
return None
# --- Load models and clients ---
tokenizer, model, device, model_labels = load_classifier_model()
groq_client = initialize_groq_client()
# --------------------------------------------------------------------------
# Core Logic Functions
# --------------------------------------------------------------------------
def classify_text(text, threshold=0.5):
"""Classifies input text using the loaded multi-label model."""
if model is None or tokenizer is None or device is None or not model_labels:
st.error("Classifier model/tokenizer not loaded properly.")
return None
start_time = time.time()
try:
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
results = {}
for i, label in enumerate(model_labels):
if i < len(probabilities):
prob = probabilities[i]
if prob > threshold:
results[label] = round(float(prob), 4)
else:
print(f"Warning: Index {i} out of bounds for probabilities")
end_time = time.time()
print(f"Classification took {end_time - start_time:.4f} seconds.")
return results
except Exception as e:
st.error(f"An error occurred during classification: {e}")
print(f"An error occurred during classification: {e}")
return None
def rewrite_text_groq(original_text, detected_labels_dict, persona="helpful assistant", tone="neutral"):
"""Rewrites the input text using the Groq API."""
if not groq_client:
st.error("Groq client not initialized. Cannot perform rewrite.")
return "Error: Groq client not initialized."
# Construct the prompt (same logic as before)
if not detected_labels_dict:
detected_labels_list_str = "None relevant"
prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"
Rewrite the message in a {tone} tone while keeping its essential meaning intact. Since no specific problematic categories were flagged, focus on ensuring the tone is appropriate and constructive."""
else:
detected_labels_list_str = ", ".join(detected_labels_dict.keys())
prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"
Rewrite the message in a {tone} tone while keeping its essential meaning intact.
Explain briefly why the original might be perceived as unsafe or negative, focusing on the potential impact rather than just listing labels.
Ensure the rewritten message does NOT contain content related to the following categories: {detected_labels_list_str}. The goal is a safer, constructive alternative."""
print("\n--- Sending Request to Groq ---")
print(f"Model: {LLM_MODEL_GROQ}")
# print(f"Prompt:\n{prompt_template}\n" + "-"*20) # Avoid printing long prompts in logs
start_time = time.time()
try:
chat_completion = groq_client.chat.completions.create(
messages=[{"role": "user", "content": prompt_template}],
model=LLM_MODEL_GROQ,
temperature=0.6,
max_tokens=350, # Increased slightly
)
end_time = time.time()
print(f"Groq response received in {end_time - start_time:.2f} seconds.")
rewritten_content = chat_completion.choices[0].message.content.strip()
return rewritten_content
except Exception as e:
st.error(f"Error interacting with Groq: {e}")
print(f"Error interacting with Groq: {e}")
return f"Error: Failed to get rewrite from Groq. {e}"
def moderation_pipeline(input_text, classification_threshold=0.5):
"""Runs the full classification and rewrite pipeline."""
print(f"\n--- Running Streamlit Pipeline for input ---")
pipeline_results = {
"original_text": input_text,
"detected_labels": {},
"rewrite_attempt": "(Not Attempted)",
"error": None
}
# 1. Classification
class_results = classify_text(input_text, threshold=classification_threshold)
if class_results is None:
pipeline_results["error"] = "Classification failed. Check logs."
return pipeline_results
pipeline_results["detected_labels"] = class_results
print(f"Classification Results: {class_results if class_results else 'None above threshold'}")
# 2. Rewrite (using Groq)
rewrite = rewrite_text_groq(input_text, class_results, persona="content moderator", tone="neutral and constructive")
pipeline_results["rewrite_attempt"] = rewrite
print("--- Pipeline Finished ---")
return pipeline_results
# --------------------------------------------------------------------------
# Streamlit UI Layout
# --------------------------------------------------------------------------
st.title("NeuroShield Proof-of-Concept")
st.markdown("A demonstration using a pre-trained toxicity classifier (`unitary/toxic-bert`) and an LLM rewrite suggestion via Groq API (`llama3-8b`). Enter text below and click 'Moderate'.")
st.markdown("---") # Separator
# Initialize session state to hold results
if 'pipeline_results' not in st.session_state:
st.session_state.pipeline_results = None
# Input Text Area
user_input = st.text_area("Enter text to moderate:", height=100, key="user_input_area")
# Moderate Button
if st.button("Moderate Text", key="moderate_button"):
if user_input:
# Show a spinner while processing
with st.spinner("Moderating..."):
# Check if prerequisites are loaded
if model and tokenizer and groq_client:
results = moderation_pipeline(user_input)
st.session_state.pipeline_results = results # Store results in session state
else:
st.error("Models or API client failed to load. Cannot moderate.")
st.session_state.pipeline_results = {"error": "Models or API client failed to load."}
else:
st.warning("Please enter some text to moderate.")
st.session_state.pipeline_results = None # Clear results if input is empty
# Display Results (using columns for better layout)
if st.session_state.pipeline_results:
results = st.session_state.pipeline_results
st.markdown("---") # Separator
st.subheader("Moderation Results")
col1, col2 = st.columns(2)
with col1:
st.metric(label="Input Text Status", value="Processed")
st.markdown("**Detected Labels & Scores**")
if results.get("error"):
st.error(f"Pipeline Error: {results['error']}")
elif results.get("detected_labels"):
st.json(results["detected_labels"])
else:
st.success("No problematic labels detected above threshold.")
with col2:
st.markdown("**Rewrite Suggestion**")
rewrite_text = results.get("rewrite_attempt", "Rewrite not generated.")
# Use a text area to display the rewrite, making it copyable
st.text_area("Suggested Rewrite:", value=rewrite_text, height=250, disabled=True, key="rewrite_output_area")
# Optional: Add footer or more info
st.markdown("---")
st.caption("Powered by Hugging Face Transformers and Groq API.")