Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import time | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from groq import Groq | |
# --------------------------------------------------------------------------
# Configuration & Model Loading (Cached for efficiency)
# --------------------------------------------------------------------------
# Hugging Face model id for the multi-label toxicity classifier.
CLASSIFIER_MODEL_NAME = "unitary/toxic-bert"
# Groq-hosted LLM used for rewrite suggestions.
LLM_MODEL_GROQ = "llama3-8b-8192"  # Or mixtral-8x7b-32768
# Must run before any other Streamlit call in the script.
st.set_page_config(page_title="NeuroShield PoC", layout="wide")
# Use Streamlit's caching for expensive operations like model loading.
# FIX: the comment above announced caching but no decorator was present,
# so the model was re-downloaded/re-loaded on every Streamlit rerun.
@st.cache_resource
def load_classifier_model():
    """Load the toxicity classifier and tokenizer once per process.

    Returns:
        tuple: ``(tokenizer, model, device, model_labels)`` on success, or
        ``(None, None, None, [])`` if loading fails. ``model_labels`` is
        the ordered list of label names taken from the model config.
    """
    print("Loading classifier model and tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME)
        # Determine device (use CPU on free HF Spaces usually, unless GPU assigned)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()  # inference mode: disables dropout etc.
        print(f"Classifier model loaded on {device}.")
        # Recover the ordered label names from the model config.
        model_labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
        return tokenizer, model, device, model_labels
    except Exception as e:
        # Surface the failure both in the UI and the server log; callers
        # must handle the all-None sentinel return.
        st.error(f"Error loading classifier model: {e}")
        print(f"Error loading classifier model: {e}")
        return None, None, None, []
def initialize_groq_client():
    """Initializes the Groq client using API key from secrets."""
    print("Initializing Groq client...")
    try:
        # Prefer the environment variable (HF Spaces); fall back to
        # Streamlit secrets (secrets.toml) for local development.
        api_key = os.environ.get('GROQ_API_KEY')
        if not api_key:
            try:
                api_key = st.secrets["GROQ_API_KEY"]
            except Exception:
                st.warning("GROQ_API_KEY not found in environment variables or st.secrets.")
                return None
        if api_key:
            client = Groq(api_key=api_key)
            print("Groq client initialized.")
            return client
        # Key resolved to an empty value — treat as unconfigured.
        st.warning("Groq API Key not configured.")
        return None
    except Exception as e:
        st.error(f"Error initializing Groq client: {e}")
        print(f"Error initializing Groq client: {e}")
        return None
# --- Load models and clients ---
# Module-level singletons shared by the functions below. Any of these may
# be None / empty if loading failed; each consumer guards for that case.
tokenizer, model, device, model_labels = load_classifier_model()
groq_client = initialize_groq_client()
| # -------------------------------------------------------------------------- | |
| # Core Logic Functions | |
| # -------------------------------------------------------------------------- | |
def classify_text(text, threshold=0.5):
    """Classifies input text using the loaded multi-label model.

    Args:
        text: Raw user text to score.
        threshold: Minimum sigmoid probability for a label to be reported.

    Returns:
        dict: ``label -> probability`` (rounded to 4 dp) for labels above
        ``threshold``; empty dict when none exceed it; ``None`` on error
        or when the classifier failed to load.
    """
    if model is None or tokenizer is None or device is None or not model_labels:
        st.error("Classifier model/tokenizer not loaded properly.")
        return None
    start_time = time.time()
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Multi-label head: independent sigmoid per label.
        # FIX: index the single batch row with [0] instead of squeeze();
        # squeeze() would collapse a one-label model's (1, 1) logits to a
        # 0-d array and make len(probabilities) below raise TypeError.
        probabilities = torch.sigmoid(outputs.logits)[0].cpu().numpy()
        results = {}
        for i, label in enumerate(model_labels):
            if i < len(probabilities):
                prob = probabilities[i]
                if prob > threshold:
                    results[label] = round(float(prob), 4)
            else:
                # Defensive: config labels and logit width should always agree.
                print(f"Warning: Index {i} out of bounds for probabilities")
        end_time = time.time()
        print(f"Classification took {end_time - start_time:.4f} seconds.")
        return results
    except Exception as e:
        st.error(f"An error occurred during classification: {e}")
        print(f"An error occurred during classification: {e}")
        return None
def rewrite_text_groq(original_text, detected_labels_dict, persona="helpful assistant", tone="neutral"):
    """Rewrites the input text using the Groq API.

    Args:
        original_text: The user's original message.
        detected_labels_dict: ``label -> score`` mapping from classify_text;
            empty/falsy when nothing was flagged.
        persona: Role the LLM is asked to adopt in the prompt.
        tone: Target tone for the rewrite.

    Returns:
        str: The rewritten message, or an ``"Error: ..."`` string on failure.
    """
    if not groq_client:
        st.error("Groq client not initialized. Cannot perform rewrite.")
        return "Error: Groq client not initialized."
    # Build the prompt; the flagged-labels variant additionally asks the
    # model to avoid the detected categories explicitly.
    # FIX: removed the dead `detected_labels_list_str = "None relevant"`
    # assignment in the no-labels branch — it was never referenced.
    if not detected_labels_dict:
        prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"
Rewrite the message in a {tone} tone while keeping its essential meaning intact. Since no specific problematic categories were flagged, focus on ensuring the tone is appropriate and constructive."""
    else:
        detected_labels_list_str = ", ".join(detected_labels_dict.keys())
        prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"
Rewrite the message in a {tone} tone while keeping its essential meaning intact.
Explain briefly why the original might be perceived as unsafe or negative, focusing on the potential impact rather than just listing labels.
Ensure the rewritten message does NOT contain content related to the following categories: {detected_labels_list_str}. The goal is a safer, constructive alternative."""
    print("\n--- Sending Request to Groq ---")
    print(f"Model: {LLM_MODEL_GROQ}")
    # print(f"Prompt:\n{prompt_template}\n" + "-"*20) # Avoid printing long prompts in logs
    start_time = time.time()
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template}],
            model=LLM_MODEL_GROQ,
            temperature=0.6,
            max_tokens=350,  # Increased slightly
        )
        end_time = time.time()
        print(f"Groq response received in {end_time - start_time:.2f} seconds.")
        rewritten_content = chat_completion.choices[0].message.content.strip()
        return rewritten_content
    except Exception as e:
        st.error(f"Error interacting with Groq: {e}")
        print(f"Error interacting with Groq: {e}")
        return f"Error: Failed to get rewrite from Groq. {e}"
def moderation_pipeline(input_text, classification_threshold=0.5):
    """Runs the full classification and rewrite pipeline."""
    print(f"\n--- Running Streamlit Pipeline for input ---")
    results = {
        "original_text": input_text,
        "detected_labels": {},
        "rewrite_attempt": "(Not Attempted)",
        "error": None,
    }
    # Step 1: classification. A None result means the classifier itself
    # failed (distinct from "nothing flagged", which is an empty dict).
    labels = classify_text(input_text, threshold=classification_threshold)
    if labels is None:
        results["error"] = "Classification failed. Check logs."
        return results
    results["detected_labels"] = labels
    print(f"Classification Results: {labels if labels else 'None above threshold'}")
    # Step 2: always request a Groq rewrite, whether or not labels fired.
    results["rewrite_attempt"] = rewrite_text_groq(
        input_text, labels, persona="content moderator", tone="neutral and constructive"
    )
    print("--- Pipeline Finished ---")
    return results
# --------------------------------------------------------------------------
# Streamlit UI Layout
# --------------------------------------------------------------------------
# NOTE: Streamlit re-executes this script top-to-bottom on every user
# interaction; widget declaration order below defines the page layout.
st.title("NeuroShield Proof-of-Concept")
st.markdown("A demonstration using a pre-trained toxicity classifier (`unitary/toxic-bert`) and an LLM rewrite suggestion via Groq API (`llama3-8b`). Enter text below and click 'Moderate'.")
st.markdown("---")  # Separator
# Initialize session state to hold results across reruns.
if 'pipeline_results' not in st.session_state:
    st.session_state.pipeline_results = None
# Input Text Area
user_input = st.text_area("Enter text to moderate:", height=100, key="user_input_area")
# Moderate Button — truthy only on the rerun triggered by the click.
if st.button("Moderate Text", key="moderate_button"):
    if user_input:
        # Show a spinner while processing
        with st.spinner("Moderating..."):
            # Check if prerequisites are loaded
            if model and tokenizer and groq_client:
                results = moderation_pipeline(user_input)
                st.session_state.pipeline_results = results  # Store results in session state
            else:
                st.error("Models or API client failed to load. Cannot moderate.")
                st.session_state.pipeline_results = {"error": "Models or API client failed to load."}
    else:
        st.warning("Please enter some text to moderate.")
        st.session_state.pipeline_results = None  # Clear results if input is empty
# Display Results (using columns for better layout). Reading from session
# state keeps results visible on subsequent reruns.
if st.session_state.pipeline_results:
    results = st.session_state.pipeline_results
    st.markdown("---")  # Separator
    st.subheader("Moderation Results")
    col1, col2 = st.columns(2)
    with col1:
        st.metric(label="Input Text Status", value="Processed")
        st.markdown("**Detected Labels & Scores**")
        if results.get("error"):
            st.error(f"Pipeline Error: {results['error']}")
        elif results.get("detected_labels"):
            st.json(results["detected_labels"])
        else:
            st.success("No problematic labels detected above threshold.")
    with col2:
        st.markdown("**Rewrite Suggestion**")
        rewrite_text = results.get("rewrite_attempt", "Rewrite not generated.")
        # Use a text area to display the rewrite, making it copyable
        st.text_area("Suggested Rewrite:", value=rewrite_text, height=250, disabled=True, key="rewrite_output_area")
# Optional: Add footer or more info
st.markdown("---")
st.caption("Powered by Hugging Face Transformers and Groq API.")