File size: 9,946 Bytes
68635d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import streamlit as st
import os
import time
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from groq import Groq

# --------------------------------------------------------------------------
# Configuration & Model Loading (Cached for efficiency)
# --------------------------------------------------------------------------
# Multi-label toxicity classifier hosted on the Hugging Face Hub.
CLASSIFIER_MODEL_NAME = "unitary/toxic-bert"
# Groq-hosted LLM used for rewrite suggestions.
LLM_MODEL_GROQ = "llama3-8b-8192"  # Or mixtral-8x7b-32768

# Must be the first Streamlit call in the script.
st.set_page_config(page_title="NeuroShield PoC", layout="wide")

# Use Streamlit's caching for expensive operations like model loading
@st.cache_resource
def load_classifier_model():
    """Load the toxicity classifier and its tokenizer (cached across reruns).

    Returns:
        Tuple of ``(tokenizer, model, device, labels)`` on success, where
        ``labels`` is the model's label list in id order, or
        ``(None, None, None, [])`` if anything fails to load.
    """
    print("Loading classifier model and tokenizer...")
    try:
        tok = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
        clf = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME)
        # Prefer GPU when one is assigned; free HF Spaces usually expose CPU only.
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        clf.to(dev)
        clf.eval()
        print(f"Classifier model loaded on {dev}.")
        # Recover human-readable label names from the model config, in id order.
        labels = [clf.config.id2label[idx] for idx in range(clf.config.num_labels)]
        return tok, clf, dev, labels
    except Exception as e:
        st.error(f"Error loading classifier model: {e}")
        print(f"Error loading classifier model: {e}")
        return None, None, None, []

@st.cache_resource
def initialize_groq_client():
    """Initialize the Groq API client (cached across reruns).

    Looks for the API key in the ``GROQ_API_KEY`` environment variable first
    (HF Spaces convention), then falls back to ``st.secrets`` (Streamlit
    Community Cloud / local secrets.toml).

    Returns:
        A configured ``Groq`` client, or None when no key is available or
        client construction fails.
    """
    print("Initializing Groq client...")
    try:
        api_key = os.environ.get('GROQ_API_KEY')
        if not api_key:
            # Fallback for local testing if using secrets.toml
            try:
                api_key = st.secrets["GROQ_API_KEY"]
            except Exception:
                st.warning("GROQ_API_KEY not found in environment variables or st.secrets.")
                return None
        if not api_key:
            st.warning("Groq API Key not configured.")
            return None
        client = Groq(api_key=api_key)
        print("Groq client initialized.")
        return client
    except Exception as e:
        st.error(f"Error initializing Groq client: {e}")
        print(f"Error initializing Groq client: {e}")
        return None

# --- Load models and clients ---
# Both calls are @st.cache_resource-cached, so repeated Streamlit reruns
# reuse the same objects. Any of these may be None (or [] for model_labels)
# if loading failed; downstream code guards against that.
tokenizer, model, device, model_labels = load_classifier_model()
groq_client = initialize_groq_client()

# --------------------------------------------------------------------------
# Core Logic Functions
# --------------------------------------------------------------------------
def classify_text(text, threshold=0.5):
    """Classify input text with the loaded multi-label toxicity model.

    Args:
        text: Raw user string to score.
        threshold: Minimum per-label sigmoid probability for a label to be
            included in the result.

    Returns:
        dict mapping label name -> probability (rounded to 4 places) for
        every label scoring above ``threshold``; empty dict when nothing is
        flagged; None if the model is unavailable or inference raises.
    """
    if model is None or tokenizer is None or device is None or not model_labels:
        st.error("Classifier model/tokenizer not loaded properly.")
        return None

    start_time = time.time()
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Multi-label head: independent sigmoid per label. squeeze(0) removes
        # only the batch dimension — a bare .squeeze() would also collapse the
        # label axis for a single-label model, producing a 0-d array and
        # breaking the len()/indexing below.
        probabilities = torch.sigmoid(outputs.logits).squeeze(0).cpu().numpy()
        results = {}
        for i, label in enumerate(model_labels):
            if i < len(probabilities):
                prob = probabilities[i]
                if prob > threshold:
                    results[label] = round(float(prob), 4)
            else:
                # Defensive: should not happen when model_labels came from the
                # same model config as the logits.
                print(f"Warning: Index {i} out of bounds for probabilities")

        end_time = time.time()
        print(f"Classification took {end_time - start_time:.4f} seconds.")
        return results

    except Exception as e:
        st.error(f"An error occurred during classification: {e}")
        print(f"An error occurred during classification: {e}")
        return None


def rewrite_text_groq(original_text, detected_labels_dict, persona="helpful assistant", tone="neutral"):
    """Ask the Groq-hosted LLM for a safer rewrite of ``original_text``.

    Args:
        original_text: The user's original message.
        detected_labels_dict: Labels flagged by the classifier (may be empty).
        persona: Role the LLM is asked to adopt in the prompt.
        tone: Desired tone of the rewritten message.

    Returns:
        The rewritten text, or an ``"Error: ..."`` string on failure.
    """
    if not groq_client:
        st.error("Groq client not initialized. Cannot perform rewrite.")
        return "Error: Groq client not initialized."

    # Two prompt variants: one that names the flagged categories and asks for
    # a brief impact explanation, and a lighter one when nothing was flagged.
    if detected_labels_dict:
        flagged = ", ".join(detected_labels_dict.keys())
        prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"

Rewrite the message in a {tone} tone while keeping its essential meaning intact.

Explain briefly why the original might be perceived as unsafe or negative, focusing on the potential impact rather than just listing labels.

Ensure the rewritten message does NOT contain content related to the following categories: {flagged}. The goal is a safer, constructive alternative."""
    else:
        prompt_template = f"""You are a {persona}. A user wrote: "{original_text}"

Rewrite the message in a {tone} tone while keeping its essential meaning intact. Since no specific problematic categories were flagged, focus on ensuring the tone is appropriate and constructive."""

    print("\n--- Sending Request to Groq ---")
    print(f"Model: {LLM_MODEL_GROQ}")
    # Prompt intentionally not logged — it can be long and contains user text.

    start_time = time.time()
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template}],
            model=LLM_MODEL_GROQ,
            temperature=0.6,
            max_tokens=350,  # room for both explanation and rewrite
        )
        elapsed = time.time() - start_time
        print(f"Groq response received in {elapsed:.2f} seconds.")
        return chat_completion.choices[0].message.content.strip()

    except Exception as e:
        st.error(f"Error interacting with Groq: {e}")
        print(f"Error interacting with Groq: {e}")
        return f"Error: Failed to get rewrite from Groq. {e}"


def moderation_pipeline(input_text, classification_threshold=0.5):
    """Run classification then the Groq rewrite over ``input_text``.

    Args:
        input_text: Text to moderate.
        classification_threshold: Probability cutoff passed to the classifier.

    Returns:
        dict with keys ``original_text``, ``detected_labels``,
        ``rewrite_attempt`` and ``error`` (None unless classification failed).
    """
    print(f"\n--- Running Streamlit Pipeline for input ---")
    report = {
        "original_text": input_text,
        "detected_labels": {},
        "rewrite_attempt": "(Not Attempted)",
        "error": None,
    }

    # Step 1: multi-label toxicity classification.
    labels = classify_text(input_text, threshold=classification_threshold)
    if labels is None:
        # Classifier errored out — skip the rewrite entirely.
        report["error"] = "Classification failed. Check logs."
        return report
    report["detected_labels"] = labels
    print(f"Classification Results: {labels if labels else 'None above threshold'}")

    # Step 2: LLM rewrite via Groq (attempted even when nothing was flagged).
    report["rewrite_attempt"] = rewrite_text_groq(
        input_text, labels, persona="content moderator", tone="neutral and constructive"
    )

    print("--- Pipeline Finished ---")
    return report

# --------------------------------------------------------------------------
# Streamlit UI Layout
# --------------------------------------------------------------------------
# Top-level script: Streamlit re-executes this on every interaction, so
# results are kept in st.session_state to survive reruns.

st.title("NeuroShield Proof-of-Concept")
st.markdown("A demonstration using a pre-trained toxicity classifier (`unitary/toxic-bert`) and an LLM rewrite suggestion via Groq API (`llama3-8b`). Enter text below and click 'Moderate'.")
st.markdown("---") # Separator

# Initialize session state to hold results across reruns
if 'pipeline_results' not in st.session_state:
    st.session_state.pipeline_results = None

# Input Text Area
user_input = st.text_area("Enter text to moderate:", height=100, key="user_input_area")

# Moderate Button — runs the full pipeline on click
if st.button("Moderate Text", key="moderate_button"):
    if user_input:
        # Show a spinner while processing
        with st.spinner("Moderating..."):
            # Check if prerequisites (classifier + Groq client) loaded at startup
            if model and tokenizer and groq_client:
                results = moderation_pipeline(user_input)
                st.session_state.pipeline_results = results # Store results in session state
            else:
                 st.error("Models or API client failed to load. Cannot moderate.")
                 st.session_state.pipeline_results = {"error": "Models or API client failed to load."}
    else:
        st.warning("Please enter some text to moderate.")
        st.session_state.pipeline_results = None # Clear results if input is empty

# Display Results (using columns for better layout); rendered on every rerun
# while results are present in session state.
if st.session_state.pipeline_results:
    results = st.session_state.pipeline_results
    st.markdown("---") # Separator
    st.subheader("Moderation Results")

    col1, col2 = st.columns(2)

    # Left column: classification outcome (error, labels, or all-clear)
    with col1:
        st.metric(label="Input Text Status", value="Processed")
        st.markdown("**Detected Labels & Scores**")
        if results.get("error"):
             st.error(f"Pipeline Error: {results['error']}")
        elif results.get("detected_labels"):
             st.json(results["detected_labels"])
        else:
             st.success("No problematic labels detected above threshold.")

    # Right column: the LLM's suggested rewrite
    with col2:
        st.markdown("**Rewrite Suggestion**")
        rewrite_text = results.get("rewrite_attempt", "Rewrite not generated.")
        # Use a text area to display the rewrite, making it copyable
        st.text_area("Suggested Rewrite:", value=rewrite_text, height=250, disabled=True, key="rewrite_output_area")

# Optional: Add footer or more info
st.markdown("---")
st.caption("Powered by Hugging Face Transformers and Groq API.")