import os
import time

import streamlit as st
import torch
from groq import Groq
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# --------------------------------------------------------------------------
# Configuration & Model Loading (Cached for efficiency)
# --------------------------------------------------------------------------
CLASSIFIER_MODEL_NAME = "unitary/toxic-bert"
LLM_MODEL_GROQ = "llama3-8b-8192"  # Or mixtral-8x7b-32768

st.set_page_config(page_title="NeuroShield PoC", layout="wide")


# Use Streamlit's caching for expensive operations like model loading
@st.cache_resource
def load_classifier_model():
    """Load the multi-label toxicity classifier and its tokenizer.

    Returns:
        tuple: (tokenizer, model, device, model_labels). On failure the
        first three slots are None and model_labels is an empty list; the
        error is surfaced in the UI via st.error.
    """
    print("Loading classifier model and tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME)
        # Determine device (use CPU on free HF Spaces usually, unless GPU assigned)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()  # inference only — disable dropout etc.
        print(f"Classifier model loaded on {device}.")
        # Label names come straight from the model config (id -> label map)
        model_labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
        return tokenizer, model, device, model_labels
    except Exception as e:
        st.error(f"Error loading classifier model: {e}")
        print(f"Error loading classifier model: {e}")
        return None, None, None, []


@st.cache_resource
def initialize_groq_client():
    """Initializes the Groq client using API key from secrets.

    The key is read from the GROQ_API_KEY environment variable first
    (HF Spaces convention), falling back to st.secrets for local testing.

    Returns:
        Groq | None: a ready client, or None when no key is configured or
        client construction fails (a warning/error is shown in the UI).
    """
    print("Initializing Groq client...")
    try:
        # Use st.secrets for Streamlit Community Cloud or os.environ for HF Spaces
        groq_api_key = os.environ.get('GROQ_API_KEY')
        if not groq_api_key:
            # Fallback for local testing if using secrets.toml
            try:
                groq_api_key = st.secrets["GROQ_API_KEY"]
            except Exception:
                st.warning("GROQ_API_KEY not found in environment variables or st.secrets.")
                return None
        if not groq_api_key:
            # Key slot exists but is empty/falsy.
            st.warning("Groq API Key not configured.")
            return None
        client = Groq(api_key=groq_api_key)
        print("Groq client initialized.")
        return client
    except Exception as e:
        st.error(f"Error initializing Groq client: {e}")
        print(f"Error initializing Groq client: {e}")
        return None


# --- Load models and clients ---
tokenizer, model, device, model_labels = load_classifier_model()
groq_client = initialize_groq_client()


# --------------------------------------------------------------------------
# Core Logic Functions
# --------------------------------------------------------------------------
def classify_text(text, threshold=0.5):
    """Classify input text using the loaded multi-label model.

    Args:
        text: Raw input string to score.
        threshold: Minimum sigmoid probability for a label to be reported.

    Returns:
        dict | None: {label: probability} for labels scoring above
        ``threshold`` (possibly empty), or None when the model is
        unavailable or classification raises.
    """
    if model is None or tokenizer is None or device is None or not model_labels:
        st.error("Classifier model/tokenizer not loaded properly.")
        return None
    start_time = time.time()
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse the label axis for a single-label model, producing a 0-d
        # array that breaks len()/indexing below.
        probabilities = torch.sigmoid(outputs.logits).squeeze(0).cpu().numpy()
        results = {}
        for i, label in enumerate(model_labels):
            if i < len(probabilities):
                prob = probabilities[i]
                if prob > threshold:
                    results[label] = round(float(prob), 4)
            else:
                print(f"Warning: Index {i} out of bounds for probabilities")
        end_time = time.time()
        print(f"Classification took {end_time - start_time:.4f} seconds.")
        return results
    except Exception as e:
        st.error(f"An error occurred during classification: {e}")
        print(f"An error occurred during classification: {e}")
        return None


def rewrite_text_groq(original_text, detected_labels_dict, persona="helpful assistant", tone="neutral"):
    """Rewrites the input text using the Groq API.

    Args:
        original_text: The user's original message.
        detected_labels_dict: {label: score} mapping from classify_text; an
            empty dict means nothing was flagged and a softer prompt is used.
        persona: Role the LLM should adopt in the prompt.
        tone: Desired tone of the rewrite.

    Returns:
        str: The rewritten message, or an "Error: ..." string on failure.
    """
    if not groq_client:
        st.error("Groq client not initialized. Cannot perform rewrite.")
        return "Error: Groq client not initialized."
    # Construct the prompt (same logic as before)
    if not detected_labels_dict:
        detected_labels_list_str = "None relevant"
        prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"
Rewrite the message in a {tone} tone while keeping its essential meaning intact.
Since no specific problematic categories were flagged, focus on ensuring the tone is appropriate and constructive."""
    else:
        detected_labels_list_str = ", ".join(detected_labels_dict.keys())
        prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"
Rewrite the message in a {tone} tone while keeping its essential meaning intact.
Explain briefly why the original might be perceived as unsafe or negative, focusing on the potential impact rather than just listing labels.
Ensure the rewritten message does NOT contain content related to the following categories: {detected_labels_list_str}.
The goal is a safer, constructive alternative."""
    print("\n--- Sending Request to Groq ---")
    print(f"Model: {LLM_MODEL_GROQ}")
    # print(f"Prompt:\n{prompt_template}\n" + "-"*20)  # Avoid printing long prompts in logs
    start_time = time.time()
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template}],
            model=LLM_MODEL_GROQ,
            temperature=0.6,
            max_tokens=350,  # Increased slightly
        )
        end_time = time.time()
        print(f"Groq response received in {end_time - start_time:.2f} seconds.")
        rewritten_content = chat_completion.choices[0].message.content.strip()
        return rewritten_content
    except Exception as e:
        st.error(f"Error interacting with Groq: {e}")
        print(f"Error interacting with Groq: {e}")
        return f"Error: Failed to get rewrite from Groq. {e}"


def moderation_pipeline(input_text, classification_threshold=0.5):
    """Runs the full classification and rewrite pipeline.

    Args:
        input_text: Text to moderate.
        classification_threshold: Passed through to classify_text.

    Returns:
        dict: original_text, detected_labels, rewrite_attempt, and an
        ``error`` key (None on success).
    """
    print("\n--- Running Streamlit Pipeline for input ---")
    pipeline_results = {
        "original_text": input_text,
        "detected_labels": {},
        "rewrite_attempt": "(Not Attempted)",
        "error": None,
    }
    # 1. Classification
    class_results = classify_text(input_text, threshold=classification_threshold)
    if class_results is None:
        pipeline_results["error"] = "Classification failed. Check logs."
        return pipeline_results
    pipeline_results["detected_labels"] = class_results
    print(f"Classification Results: {class_results if class_results else 'None above threshold'}")
    # 2. Rewrite (using Groq)
    rewrite = rewrite_text_groq(
        input_text,
        class_results,
        persona="content moderator",
        tone="neutral and constructive",
    )
    pipeline_results["rewrite_attempt"] = rewrite
    print("--- Pipeline Finished ---")
    return pipeline_results


# --------------------------------------------------------------------------
# Streamlit UI Layout
# --------------------------------------------------------------------------
st.title("NeuroShield Proof-of-Concept")
st.markdown("A demonstration using a pre-trained toxicity classifier (`unitary/toxic-bert`) and an LLM rewrite suggestion via Groq API (`llama3-8b`). Enter text below and click 'Moderate'.")
st.markdown("---")  # Separator

# Initialize session state to hold results
if 'pipeline_results' not in st.session_state:
    st.session_state.pipeline_results = None

# Input Text Area
user_input = st.text_area("Enter text to moderate:", height=100, key="user_input_area")

# Moderate Button
if st.button("Moderate Text", key="moderate_button"):
    if user_input:
        # Show a spinner while processing
        with st.spinner("Moderating..."):
            # Check if prerequisites are loaded (identity test — a loaded
            # model object must not be judged by truthiness)
            if model is not None and tokenizer is not None and groq_client is not None:
                results = moderation_pipeline(user_input)
                st.session_state.pipeline_results = results  # Store results in session state
            else:
                st.error("Models or API client failed to load. Cannot moderate.")
                st.session_state.pipeline_results = {"error": "Models or API client failed to load."}
    else:
        st.warning("Please enter some text to moderate.")
        st.session_state.pipeline_results = None  # Clear results if input is empty

# Display Results (using columns for better layout)
if st.session_state.pipeline_results:
    results = st.session_state.pipeline_results
    st.markdown("---")  # Separator
    st.subheader("Moderation Results")
    col1, col2 = st.columns(2)
    with col1:
        st.metric(label="Input Text Status", value="Processed")
        st.markdown("**Detected Labels & Scores**")
        if results.get("error"):
            st.error(f"Pipeline Error: {results['error']}")
        elif results.get("detected_labels"):
            st.json(results["detected_labels"])
        else:
            st.success("No problematic labels detected above threshold.")
    with col2:
        st.markdown("**Rewrite Suggestion**")
        rewrite_text = results.get("rewrite_attempt", "Rewrite not generated.")
        # Use a text area to display the rewrite, making it copyable
        st.text_area("Suggested Rewrite:", value=rewrite_text, height=250, disabled=True, key="rewrite_output_area")

# Optional: Add footer or more info
st.markdown("---")
st.caption("Powered by Hugging Face Transformers and Groq API.")