sidchak committed on
Commit
7abc132
·
1 Parent(s): 0bfadbf

Implement async model initialization: profanity_detector.py

Browse files
Files changed (1) hide show
  1. profanity_detector.py +964 -0
profanity_detector.py ADDED
@@ -0,0 +1,964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
4
+ import whisper
5
+ import gradio as gr
6
+ import re
7
+ import pandas as pd
8
+ import numpy as np
9
+ import os
10
+ import time
11
+ import logging
12
+ import threading
13
+ import queue
14
+ from scipy.io.wavfile import write as write_wav
15
+ from html import escape
16
+ import traceback
17
+ import spaces # Required for Hugging Face ZeroGPU compatibility
18
+
19
+ # Configure logging
20
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
23
+ handlers=[logging.StreamHandler()]
24
+ )
25
+ logger = logging.getLogger('profanity_detector')
26
+
27
+ # Detect if we're running in a ZeroGPU environment
28
+ IS_ZEROGPU = os.environ.get("SPACE_RUNTIME_STATELESS", "0") == "1"
29
+ if os.environ.get("SPACES_ZERO_GPU") is not None:
30
+ IS_ZEROGPU = True
31
+
32
+ # Define device strategy that works in both environments
33
+ if IS_ZEROGPU:
34
+ # In ZeroGPU: always initialize on CPU, will use GPU only in @spaces.GPU functions
35
+ device = torch.device("cpu")
36
+ logger.info("ZeroGPU environment detected. Using CPU for initial loading.")
37
+ else:
38
+ # For local runs: use CUDA if available
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ logger.info(f"Local environment. Using device: {device}")
41
+
42
+ # Global variables for models
43
+ profanity_model = None
44
+ profanity_tokenizer = None
45
+ t5_model = None
46
+ t5_tokenizer = None
47
+ whisper_model = None
48
+ tts_processor = None
49
+ tts_model = None
50
+ vocoder = None
51
+ models_loaded = False
52
+
53
+ # Default speaker embeddings for TTS
54
+ speaker_embeddings = None
55
+
56
+ # Queue for real-time audio processing
57
+ audio_queue = queue.Queue()
58
+ processing_active = False
59
+
60
# Model loading. NOTE(review): the original comment said "int8 quantization",
# but the code below only attempts fp16 (.half()) conversion on CUDA — no
# int8 quantization is applied anywhere.
def load_models():
    """
    Download/initialize all models and publish them via module globals.

    Loads, in order: the profanity classifier, the T5 detoxifier, the
    Whisper speech-to-text model, and the SpeechT5 TTS model + HiFiGAN
    vocoder. Everything is first loaded on CPU; models are moved to the
    CUDA device (and converted to half precision where possible) only
    when NOT running under ZeroGPU — in ZeroGPU mode the @spaces.GPU
    functions move models to the GPU per call instead.

    Returns:
        A human-readable status string ("Models loaded successfully." or
        an error message with traceback). Sets `models_loaded = True` on
        success.
    """
    global profanity_model, profanity_tokenizer, t5_model, t5_tokenizer, whisper_model
    global tts_processor, tts_model, vocoder, speaker_embeddings, models_loaded

    try:
        logger.info("Loading profanity detection model...")
        PROFANITY_MODEL = "parsawar/profanity_model_3.1"
        profanity_tokenizer = AutoTokenizer.from_pretrained(PROFANITY_MODEL)

        # Load model without moving to CUDA directly
        profanity_model = AutoModelForSequenceClassification.from_pretrained(
            PROFANITY_MODEL,
            device_map=None,  # Stay on CPU for now
            low_cpu_mem_usage=True
        )

        # Only move to device if NOT in ZeroGPU mode
        if not IS_ZEROGPU and torch.cuda.is_available():
            profanity_model = profanity_model.to(device)
            try:
                # fp16 halves memory; fall back to fp32 if conversion fails.
                profanity_model = profanity_model.half()
                logger.info("Successfully converted profanity model to half precision")
            except Exception as e:
                logger.warning(f"Could not convert to half precision: {str(e)}")

        logger.info("Loading detoxification model...")
        T5_MODEL = "s-nlp/t5-paranmt-detox"
        t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL)

        t5_model = AutoModelForSeq2SeqLM.from_pretrained(
            T5_MODEL,
            device_map=None,  # Stay on CPU for now
            low_cpu_mem_usage=True
        )

        # Only move to device if NOT in ZeroGPU mode
        if not IS_ZEROGPU and torch.cuda.is_available():
            t5_model = t5_model.to(device)
            try:
                t5_model = t5_model.half()
                logger.info("Successfully converted T5 model to half precision")
            except Exception as e:
                logger.warning(f"Could not convert to half precision: {str(e)}")

        logger.info("Loading Whisper speech-to-text model...")
        # Always load on CPU in ZeroGPU mode
        #whisper_model = whisper.load_model("medium" if IS_ZEROGPU else "large", device="cpu")
        whisper_model = whisper.load_model("large-v2", device="cpu")

        # Only move to device if NOT in ZeroGPU mode
        if not IS_ZEROGPU and torch.cuda.is_available():
            whisper_model = whisper_model.to(device)

        logger.info("Loading Text-to-Speech model...")
        TTS_MODEL = "microsoft/speecht5_tts"
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL)

        tts_model = SpeechT5ForTextToSpeech.from_pretrained(
            TTS_MODEL,
            device_map=None,  # Stay on CPU for now
            low_cpu_mem_usage=True
        )

        vocoder = SpeechT5HifiGan.from_pretrained(
            "microsoft/speecht5_hifigan",
            device_map=None,  # Stay on CPU for now
            low_cpu_mem_usage=True
        )

        # Only move to device if NOT in ZeroGPU mode
        if not IS_ZEROGPU and torch.cuda.is_available():
            tts_model = tts_model.to(device)
            vocoder = vocoder.to(device)

        # Speaker embeddings - always on CPU for ZeroGPU.
        # NOTE(review): an all-zero (1, 512) x-vector is not a real speaker
        # embedding; SpeechT5 will produce a generic/degraded voice with it —
        # confirm whether a pretrained speaker embedding was intended.
        speaker_embeddings = torch.zeros((1, 512))
        # Only move to device if NOT in ZeroGPU mode
        if not IS_ZEROGPU and torch.cuda.is_available():
            speaker_embeddings = speaker_embeddings.to(device)

        models_loaded = True
        logger.info("All models loaded successfully.")

        return "Models loaded successfully."
    except Exception as e:
        error_msg = f"Error loading models: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg
149
+
150
# ZeroGPU decorator: Requests GPU resources when function is called and releases them when completed.
# This enables efficient GPU sharing in Hugging Face Spaces while having no effect in local environments.
@spaces.GPU
def detect_profanity(text: str, threshold: float = 0.5):
    """
    Detect profanity in text with adjustable threshold.

    Args:
        text: The input text to analyze.
        threshold: Profanity detection threshold (0.0-1.0).

    Returns:
        Dictionary with analysis results: original text, overall score,
        a boolean "profanity" flag, the list of detected profane words,
        an HTML-highlighted copy of the text, and per-word scores. On
        failure, a dictionary containing an "error" key is returned.
    """
    if not models_loaded:
        return {"error": "Models not loaded yet. Please wait."}

    try:
        # Score the whole text first.
        inputs = profanity_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        # In ZeroGPU, move model + inputs to GPU inside this @spaces.GPU
        # function; in a local CUDA environment the model already lives there.
        current_device = device
        if IS_ZEROGPU and torch.cuda.is_available():
            current_device = torch.device("cuda")
            inputs = inputs.to(current_device)
            profanity_model.to(current_device)
        elif torch.cuda.is_available():  # Local environment with CUDA
            inputs = inputs.to(current_device)

        with torch.no_grad():
            outputs = profanity_model(**inputs).logits
            # Index [0][1] is assumed to be the "toxic" class probability.
            score = torch.nn.functional.softmax(outputs, dim=1)[0][1].item()

        # If the text as a whole is flagged, score each word individually
        # to identify which words triggered the detection.
        words = re.findall(r'\b\w+\b', text)
        profane_words = []
        word_scores = {}

        if score > threshold:
            for word in words:
                if len(word) < 2:  # Skip very short words
                    continue

                word_inputs = profanity_tokenizer(word, return_tensors="pt", truncation=True, max_length=512)
                if torch.cuda.is_available():
                    word_inputs = word_inputs.to(current_device)

                with torch.no_grad():
                    word_outputs = profanity_model(**word_inputs).logits
                    word_score = torch.nn.functional.softmax(word_outputs, dim=1)[0][1].item()
                    word_scores[word] = word_score

                    if word_score > threshold:
                        profane_words.append(word.lower())

        # Move model back to CPU if in ZeroGPU mode - to free GPU memory
        if IS_ZEROGPU and torch.cuda.is_available():
            profanity_model.to(torch.device("cpu"))

        # Create highlighted version of the text
        highlighted_text = create_highlighted_text(text, profane_words)

        return {
            "text": text,
            "score": score,
            "profanity": score > threshold,
            "profane_words": profane_words,
            "highlighted_text": highlighted_text,
            "word_scores": word_scores
        }
    except Exception as e:
        error_msg = f"Error in profanity detection: {str(e)}"
        logger.error(error_msg)
        # Best-effort cleanup: return the model to CPU so ZeroGPU memory is
        # released. FIX: was a bare `except:` which would also swallow
        # SystemExit/KeyboardInterrupt.
        if IS_ZEROGPU and torch.cuda.is_available():
            try:
                profanity_model.to(torch.device("cpu"))
            except Exception:
                pass
        return {"error": error_msg, "text": text, "score": 0, "profanity": False}
233
+
234
def create_highlighted_text(text, profane_words):
    """
    Create HTML-formatted text with profane words highlighted.

    FIX: the original escaped the text only when `profane_words` was empty;
    with matches present, the raw input was interpolated into HTML unescaped,
    allowing markup injection. Now every segment (matched and unmatched) is
    HTML-escaped, consistently with the empty-list path.

    Args:
        text: The plain-text input to render.
        profane_words: Words to wrap in a red highlight span (matched
            case-insensitively on word boundaries).

    Returns:
        An HTML string safe to insert into the results panel.
    """
    if not profane_words:
        return escape(text)

    # Regex matching any of the profane words on word boundaries.
    pattern = re.compile(
        r'\b(' + '|'.join(re.escape(word) for word in profane_words) + r')\b',
        flags=re.IGNORECASE,
    )

    # Walk the matches, escaping the plain segments between them and wrapping
    # each (escaped) match in the highlight span. Escaping segment-by-segment
    # avoids matching inside entities that escaping would introduce.
    pieces = []
    last_end = 0
    for match in pattern.finditer(text):
        pieces.append(escape(text[last_end:match.start()]))
        pieces.append(
            '<span style="background-color: rgba(255, 0, 0, 0.3); '
            f'padding: 0px 2px; border-radius: 3px;">{escape(match.group(0))}</span>'
        )
        last_end = match.end()
    pieces.append(escape(text[last_end:]))
    return ''.join(pieces)
250
+
251
@spaces.GPU
def rephrase_profanity(text):
    """
    Rephrase text containing profanity using the T5 detoxification model.

    Falls back to returning the original text when the model output is
    unusable or any error occurs.

    Args:
        text: The (possibly profane) input text.

    Returns:
        The detoxified text, or the original `text` on failure, or a
        status string if models are not loaded yet.
    """
    if not models_loaded:
        return "Models not loaded yet. Please wait."

    try:
        # Rephrase using the detoxification model
        inputs = t5_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        # In ZeroGPU, move to GPU here inside the spaces.GPU function
        current_device = device
        if IS_ZEROGPU and torch.cuda.is_available():
            current_device = torch.device("cuda")
            inputs = inputs.to(current_device)
            # Only in ZeroGPU mode, we need to move the model to GPU inside the function
            t5_model.to(current_device)
        elif torch.cuda.is_available():  # Local environment with CUDA
            inputs = inputs.to(current_device)

        # Use more conservative generation settings with error handling
        try:
            outputs = t5_model.generate(
                **inputs,
                max_length=512,
                num_beams=4,  # Reduced from 5 to be more memory-efficient
                early_stopping=True,
                no_repeat_ngram_size=2,
                length_penalty=1.0
            )
            rephrased_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Verify the output is reasonable
            if not rephrased_text or len(rephrased_text) < 3:
                logger.warning(f"T5 model produced unusable output: '{rephrased_text}'")
                # FIX: the original returned here WITHOUT moving the model
                # back to CPU, leaking GPU memory in ZeroGPU mode.
                if IS_ZEROGPU and torch.cuda.is_available():
                    t5_model.to(torch.device("cpu"))
                return text  # Return original if output is too short

            # Move model back to CPU if in ZeroGPU mode - to free GPU memory
            if IS_ZEROGPU and torch.cuda.is_available():
                t5_model.to(torch.device("cpu"))

            return rephrased_text.strip()

        except RuntimeError as e:
            # Handle potential CUDA out of memory error
            if "CUDA out of memory" in str(e):
                logger.warning("CUDA out of memory in T5 model. Trying with smaller beam size...")
                # Try again with smaller beam size
                outputs = t5_model.generate(
                    **inputs,
                    max_length=512,
                    num_beams=2,  # Use smaller beam size
                    early_stopping=True
                )
                rephrased_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

                # Move model back to CPU if in ZeroGPU mode - to free GPU memory
                if IS_ZEROGPU and torch.cuda.is_available():
                    t5_model.to(torch.device("cpu"))

                return rephrased_text.strip()
            else:
                # FIX: bare `raise` preserves the original traceback
                # (the original `raise e` reset it to this line).
                raise

    except Exception as e:
        error_msg = f"Error in rephrasing: {str(e)}"
        logger.error(error_msg)
        # Best-effort cleanup. FIX: was a bare `except:`.
        if IS_ZEROGPU and torch.cuda.is_available():
            try:
                t5_model.to(torch.device("cpu"))
            except Exception:
                pass
        return text  # Return original text if rephrasing fails
327
+
328
@spaces.GPU
def text_to_speech(text):
    """
    Convert text to speech using SpeechT5 and write it to a WAV file.

    Args:
        text: The text to synthesize.

    Returns:
        Path to a 16 kHz WAV file with the synthesized speech, or None
        on failure / if models are not loaded.
    """
    if not models_loaded:
        return None

    try:
        # Create a temporary file path to save the audio.
        # NOTE(review): a per-second timestamp can collide under concurrent
        # calls within the same second; files are also never cleaned up here.
        temp_file = f"temp_tts_output_{int(time.time())}.wav"

        # Process the text input
        inputs = tts_processor(text=text, return_tensors="pt")

        # In ZeroGPU, move to GPU here inside the spaces.GPU function
        current_device = device
        if IS_ZEROGPU and torch.cuda.is_available():
            current_device = torch.device("cuda")
            inputs = inputs.to(current_device)
            # Only in ZeroGPU mode, we need to move the models to GPU inside the function
            tts_model.to(current_device)
            vocoder.to(current_device)
            speaker_embeddings_local = speaker_embeddings.to(current_device)
        elif torch.cuda.is_available():  # Local environment with CUDA
            inputs = inputs.to(current_device)
            speaker_embeddings_local = speaker_embeddings
        else:
            speaker_embeddings_local = speaker_embeddings

        # Generate speech with a fixed speaker embedding
        speech = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_embeddings_local,
            vocoder=vocoder
        )

        # Convert from PyTorch tensor to NumPy array
        speech_np = speech.cpu().numpy()

        # Move models back to CPU if in ZeroGPU mode - to free GPU memory
        if IS_ZEROGPU and torch.cuda.is_available():
            tts_model.to(torch.device("cpu"))
            vocoder.to(torch.device("cpu"))

        # Save as WAV file (sampling rate is 16kHz for SpeechT5)
        write_wav(temp_file, 16000, speech_np)

        return temp_file
    except Exception as e:
        error_msg = f"Error in text-to-speech conversion: {str(e)}"
        logger.error(error_msg)
        # Best-effort cleanup. FIX: was a bare `except:` which would also
        # swallow SystemExit/KeyboardInterrupt.
        if IS_ZEROGPU and torch.cuda.is_available():
            try:
                tts_model.to(torch.device("cpu"))
                vocoder.to(torch.device("cpu"))
            except Exception:
                pass
        return None
388
+
389
def text_analysis(input_text, threshold=0.5):
    """
    Analyze text for profanity and prepare the three Gradio outputs.

    Args:
        input_text: The raw text to analyze.
        threshold: Profanity detection sensitivity (0.0-1.0).

    Returns:
        A (summary_text, highlighted_html_or_None, audio_path_or_None)
        tuple matching the UI's result components.
    """
    if not models_loaded:
        return "Models not loaded yet. Please wait for initialization to complete.", None, None

    try:
        analysis = detect_profanity(input_text, threshold=threshold)

        # Propagate detector failures straight to the UI.
        if "error" in analysis:
            return analysis["error"], None, None

        if not analysis["profanity"]:
            # Clean input: synthesize the original text unchanged.
            spoken = text_to_speech(input_text)
            summary = (
                f"Profanity Score: {analysis['score']:.4f}\n"
                f"Profane: {analysis['profanity']}\n"
                f"Classification: No Toxicity"
            )
            return summary, None, spoken

        # Profanity detected: detoxify, classify severity, then synthesize.
        cleaned = rephrase_profanity(input_text)
        score_value = analysis["score"]

        # Highest matching severity band wins.
        label = "No Toxicity"
        for cutoff, band in (
            (0.7, "Severe Toxicity"),
            (0.5, "Moderate Toxicity"),
            (0.35, "Mild Toxicity"),
            (0.2, "Minimal Toxicity"),
        ):
            if score_value >= cutoff:
                label = band
                break

        spoken = text_to_speech(cleaned)

        summary = (
            f"Profanity Score: {analysis['score']:.4f}\n\n"
            f"Profane: {analysis['profanity']}\n"
            f"Classification: {label}\n"
            f"Detected Profane Words: {', '.join(analysis['profane_words'])}\n\n"
            f"Reworded: {cleaned}"
        )
        return summary, analysis["highlighted_text"], spoken
    except Exception as e:
        error_msg = f"Error in text analysis: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, None, None
442
+
443
# ZeroGPU decorator with custom duration: Allocates GPU for up to 120 seconds to handle longer audio processing.
# Longer durations ensure processing isn't cut off, while shorter durations improve queue priority.
@spaces.GPU(duration=120)
def analyze_audio(audio_path, threshold=0.5):
    """
    Transcribe an audio file with Whisper and analyze the transcript.

    Args:
        audio_path: Path to the uploaded/recorded audio file.
        threshold: Profanity detection sensitivity (0.0-1.0).

    Returns:
        A (summary_text, highlighted_html_or_None, audio_path_or_None)
        tuple matching the UI's result components.
    """
    if not models_loaded:
        return "Models not loaded yet. Please wait for initialization to complete.", None, None

    if not audio_path:
        return "No audio provided.", None, None

    try:
        # In ZeroGPU mode, models need to be moved to GPU
        if IS_ZEROGPU and torch.cuda.is_available():
            current_device = torch.device("cuda")
            whisper_model.to(current_device)

        # Transcribe audio
        result = whisper_model.transcribe(audio_path, fp16=torch.cuda.is_available())
        text = result["text"]

        # Move whisper model back to CPU if in ZeroGPU mode
        if IS_ZEROGPU and torch.cuda.is_available():
            whisper_model.to(torch.device("cpu"))

        # Detect profanity with user-defined threshold
        analysis = detect_profanity(text, threshold=threshold)

        # Handle error case
        if "error" in analysis:
            return f"Error during analysis: {analysis['error']}\nTranscription: {text}", None, None

        if analysis["profanity"]:
            clean_text = rephrase_profanity(text)
        else:
            clean_text = text

        # Generate audio for the rephrased text
        audio_output = text_to_speech(clean_text)

        return (
            f"Transcription: {text}\n\n"
            f"Profanity Score: {analysis['score']:.4f}\n"
            f"Profane: {'Yes' if analysis['profanity'] else 'No'}\n"
            f"Classification: {'Severe Toxicity' if analysis['score'] >= 0.7 else 'Moderate Toxicity' if analysis['score'] >= 0.5 else 'Mild Toxicity' if analysis['score'] >= 0.35 else 'Minimal Toxicity' if analysis['score'] >= 0.2 else 'No Toxicity'}\n"
            f"Profane Words: {', '.join(analysis['profane_words']) if analysis['profanity'] else 'None'}\n\n"
            f"Reworded: {clean_text}"
        ), analysis["highlighted_text"] if analysis["profanity"] else None, audio_output
    except Exception as e:
        error_msg = f"Error in audio analysis: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        # Best-effort cleanup. FIX: was a bare `except:` which would also
        # swallow SystemExit/KeyboardInterrupt.
        if IS_ZEROGPU and torch.cuda.is_available():
            try:
                whisper_model.to(torch.device("cpu"))
            except Exception:
                pass
        return error_msg, None, None
503
+
504
# Global variables to store streaming results.
# Shared mutable state between process_stream_chunk (writer) and the
# streaming UI callbacks (readers); reset by start_streaming.
stream_results = {
    "transcript": "",      # latest Whisper transcription of the stream
    "profanity_info": "",  # human-readable detection status / error message
    "clean_text": "",      # detoxified (or original) text for the last chunk
    "audio_output": None   # path to the latest synthesized WAV, if any
}
511
+
512
@spaces.GPU
def process_stream_chunk(audio_chunk):
    """Process an audio chunk from the streaming interface.

    Normalizes the chunk (Gradio versions deliver different shapes) into a
    WAV file, transcribes it with Whisper, runs profanity detection on the
    transcript, and updates the shared `stream_results` dict that the
    streaming UI components read from. Always returns the current
    (transcript, profanity_info, clean_text, audio_output) tuple, even on
    error, so the UI never breaks.
    """
    global stream_results, processing_active

    # Inactive stream or unloaded models: echo current state unchanged.
    if not processing_active or not models_loaded:
        return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"]

    try:
        # The format of audio_chunk from Gradio streaming can vary
        # It can be: (numpy_array, sample_rate), (filepath, sample_rate, numpy_array) or just numpy_array
        # Let's handle all possible cases

        if audio_chunk is None:
            # No audio received
            return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"]

        # Different Gradio versions return different formats
        temp_file = None

        if isinstance(audio_chunk, tuple):
            if len(audio_chunk) == 2:
                # Format: (numpy_array, sample_rate)
                samples, sample_rate = audio_chunk
                temp_file = f"temp_stream_{int(time.time())}.wav"
                write_wav(temp_file, sample_rate, samples)
            elif len(audio_chunk) == 3:
                # Format: (filepath, sample_rate, numpy_array)
                filepath, sample_rate, samples = audio_chunk
                # Use the provided filepath if it exists
                if os.path.exists(filepath):
                    temp_file = filepath
                else:
                    # Create our own file
                    temp_file = f"temp_stream_{int(time.time())}.wav"
                    write_wav(temp_file, sample_rate, samples)
        elif isinstance(audio_chunk, np.ndarray):
            # Just a numpy array, assume sample rate of 16000 for Whisper
            samples = audio_chunk
            sample_rate = 16000
            temp_file = f"temp_stream_{int(time.time())}.wav"
            write_wav(temp_file, sample_rate, samples)
        elif isinstance(audio_chunk, str) and os.path.exists(audio_chunk):
            # It's a filepath
            temp_file = audio_chunk
        else:
            # Unknown format
            stream_results["profanity_info"] = f"Error: Unknown audio format: {type(audio_chunk)}"
            return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"]

        # Make sure we have a valid file to process
        if not temp_file or not os.path.exists(temp_file):
            stream_results["profanity_info"] = "Error: Failed to create audio file for processing"
            return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"]

        # In ZeroGPU mode, move whisper model to GPU
        if IS_ZEROGPU and torch.cuda.is_available():
            current_device = torch.device("cuda")
            whisper_model.to(current_device)

        # Process with Whisper
        result = whisper_model.transcribe(temp_file, fp16=torch.cuda.is_available())
        transcript = result["text"].strip()

        # Move whisper model back to CPU if in ZeroGPU mode
        if IS_ZEROGPU and torch.cuda.is_available():
            whisper_model.to(torch.device("cpu"))

        # Skip processing if transcript is empty
        if not transcript:
            # Clean up temp file if we created it (the "temp_stream_" prefix
            # marks files this function wrote, as opposed to Gradio's own).
            if temp_file and temp_file.startswith("temp_stream_") and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                # NOTE(review): bare except — would also swallow
                # SystemExit/KeyboardInterrupt; narrow to OSError.
                except:
                    pass
            # Return current state, but update profanity info
            stream_results["profanity_info"] = "No speech detected. Keep talking..."
            return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"]

        # Update transcript
        stream_results["transcript"] = transcript

        # Analyze for profanity (fixed default threshold for streaming;
        # the UI sensitivity slider is not wired into this path).
        analysis = detect_profanity(transcript, threshold=0.5)

        # Check if profanity was detected
        if analysis.get("profanity", False):
            profane_words = ", ".join(analysis.get("profane_words", []))
            stream_results["profanity_info"] = f"Profanity Detected (Score: {analysis['score']:.2f})\nProfane Words: {profane_words}"

            # Rephrase to clean text
            clean_text = rephrase_profanity(transcript)
            stream_results["clean_text"] = clean_text

            # Create audio from cleaned text
            audio_file = text_to_speech(clean_text)
            if audio_file:
                stream_results["audio_output"] = audio_file
        else:
            stream_results["profanity_info"] = f"No Profanity Detected (Score: {analysis['score']:.2f})"
            stream_results["clean_text"] = transcript

            # Use original text for audio if no profanity
            audio_file = text_to_speech(transcript)
            if audio_file:
                stream_results["audio_output"] = audio_file

        # Clean up temporary file if we created it
        if temp_file and temp_file.startswith("temp_stream_") and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
            # NOTE(review): bare except — narrow to OSError.
            except:
                pass

        return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"]

    except Exception as e:
        error_msg = f"Error processing streaming audio: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)

        # Make sure all models are on CPU if in ZeroGPU mode
        if IS_ZEROGPU and torch.cuda.is_available():
            try:
                whisper_model.to(torch.device("cpu"))
                profanity_model.to(torch.device("cpu"))
                t5_model.to(torch.device("cpu"))
                tts_model.to(torch.device("cpu"))
                vocoder.to(torch.device("cpu"))
            # NOTE(review): bare except — narrow to Exception.
            except:
                pass

        # Update profanity info with error message
        stream_results["profanity_info"] = f"Error: {str(e)}"

        return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"]
648
+
649
def start_streaming():
    """Begin real-time audio processing, resetting any previous session state."""
    global processing_active, stream_results

    # Guard clauses: refuse to start without models or while already running.
    if not models_loaded:
        return "Models not loaded yet. Please wait for initialization to complete."
    if processing_active:
        return "Streaming is already active."

    # Fresh state for the new streaming session.
    stream_results = dict(
        transcript="",
        profanity_info="Waiting for audio input...",
        clean_text="",
        audio_output=None,
    )

    processing_active = True
    logger.info("Started real-time audio processing")
    return "Started real-time audio processing. Speak into your microphone."
670
+
671
def stop_streaming():
    """Halt real-time audio processing if a session is currently active."""
    global processing_active

    if processing_active:
        processing_active = False
        return "Stopped real-time audio processing."

    # Nothing to stop.
    return "Streaming is not active."
680
+
681
+ def create_ui():
682
+ """Create the Gradio UI"""
683
+ # Simple CSS for styling
684
+ css = """
685
+ /* Fix for dark mode text visibility */
686
+ .dark .gr-input,
687
+ .dark textarea,
688
+ .dark .gr-textbox,
689
+ .dark [data-testid="textbox"] {
690
+ color: white !important;
691
+ background-color: #2c303b !important;
692
+ }
693
+
694
+ .dark .gr-box,
695
+ .dark .gr-form,
696
+ .dark .gr-panel,
697
+ .dark .gr-block {
698
+ color: white !important;
699
+ }
700
+
701
+ /* Highlighted text container - with dark mode fixes */
702
+ .highlighted-text {
703
+ border: 1px solid #ddd;
704
+ border-radius: 5px;
705
+ padding: 10px;
706
+ margin: 10px 0;
707
+ background-color: #f9f9f9;
708
+ font-family: sans-serif;
709
+ max-height: 300px;
710
+ overflow-y: auto;
711
+ color: #333 !important; /* Ensure text is dark for light mode */
712
+ }
713
+
714
+ /* Dark mode specific styling for highlighted text */
715
+ .dark .highlighted-text {
716
+ background-color: #2c303b !important;
717
+ color: #ffffff !important;
718
+ border-color: #4a4f5a !important;
719
+ }
720
+
721
+ /* Make sure text in the highlighted container remains visible in both themes */
722
+ .highlighted-text, .dark .highlighted-text {
723
+ color-scheme: light dark;
724
+ }
725
+
726
+ /* Loading animation */
727
+ .loading {
728
+ display: inline-block;
729
+ width: 20px;
730
+ height: 20px;
731
+ border: 3px solid rgba(0,0,0,.3);
732
+ border-radius: 50%;
733
+ border-top-color: #3498db;
734
+ animation: spin 1s ease-in-out infinite;
735
+ }
736
+
737
+ @keyframes spin {
738
+ to { transform: rotate(360deg); }
739
+ }
740
+ """
741
+
742
+ # Create a custom theme based on Soft but explicitly set to light mode
743
+ light_theme = gr.themes.Soft(
744
+ primary_hue="blue",
745
+ secondary_hue="blue",
746
+ neutral_hue="gray"
747
+ )
748
+
749
+ # Set theme to light mode and disable theme switching
750
+ with gr.Blocks(css=css, theme=light_theme, analytics_enabled=False) as ui:
751
+ # Model initialization
752
+ init_status = gr.State("")
753
+
754
+ gr.Markdown(
755
+ """
756
+ # Profanity Detection & Replacement System
757
+ Detect, rephrase, and listen to cleaned content from text or audio!
758
+ """,
759
+ elem_classes="header"
760
+ )
761
+
762
+ # The rest of your UI code remains unchanged...
763
+ # Initialize models button with status indicators
764
+ with gr.Row():
765
+ with gr.Column(scale=3):
766
+ init_button = gr.Button("Initialize Models", variant="primary")
767
+ init_output = gr.Textbox(label="Initialization Status", interactive=False)
768
+ with gr.Column(scale=1):
769
+ model_status = gr.HTML(
770
+ """<div style="text-align: center; padding: 5px;">
771
+ <p><b>Model Status:</b> <span style="color: #e74c3c;">Not Loaded</span></p>
772
+ </div>"""
773
+ )
774
+
775
+ # Global sensitivity slider
776
+ sensitivity = gr.Slider(
777
+ minimum=0.2,
778
+ maximum=0.95,
779
+ value=0.5,
780
+ step=0.05,
781
+ label="Profanity Detection Sensitivity",
782
+ info="Lower values are more permissive, higher values are more strict"
783
+ )
784
+
785
+ with gr.Row():
786
+ with gr.Column(scale=3):
787
+ gr.Markdown("### Choose an Input Method")
788
+
789
+ # Text Analysis
790
+ with gr.Tabs():
791
+ with gr.TabItem("Text Analysis", elem_id="text-tab"):
792
+ with gr.Row():
793
+ text_input = gr.Textbox(
794
+ label="Enter Text",
795
+ placeholder="Type your text here...",
796
+ lines=5,
797
+ elem_classes="textbox"
798
+ )
799
+ with gr.Row():
800
+ text_button = gr.Button("Analyze Text", variant="primary")
801
+ clear_button = gr.Button("Clear", variant="secondary")
802
+
803
+ with gr.Row():
804
+ with gr.Column(scale=2):
805
+ text_output = gr.Textbox(label="Results", lines=10)
806
+ highlighted_output = gr.HTML(label="Detected Profanity", elem_classes="highlighted-text")
807
+ with gr.Column(scale=1):
808
+ text_audio_output = gr.Audio(label="Rephrased Audio", type="filepath")
809
+
810
+ # Audio Analysis
811
+ with gr.TabItem("Audio Analysis", elem_id="audio-tab"):
812
+ gr.Markdown("### Upload or Record Audio")
813
+ audio_input = gr.Audio(
814
+ label="Audio Input",
815
+ type="filepath",
816
+ sources=["microphone", "upload"]
817
+ #waveform_options=gr.WaveformOptions(waveform_color="#4a90e2")
818
+ )
819
+ with gr.Row():
820
+ audio_button = gr.Button("Analyze Audio", variant="primary")
821
+ clear_audio_button = gr.Button("Clear", variant="secondary")
822
+
823
+ with gr.Row():
824
+ with gr.Column(scale=2):
825
+ audio_output = gr.Textbox(label="Results", lines=10, show_copy_button=True)
826
+ audio_highlighted_output = gr.HTML(label="Detected Profanity", elem_classes="highlighted-text")
827
+ with gr.Column(scale=1):
828
+ clean_audio_output = gr.Audio(label="Rephrased Audio", type="filepath")
829
+
830
+ # Real-time Streaming
831
+ with gr.TabItem("Real-time Streaming", elem_id="streaming-tab"):
832
+ gr.Markdown("### Real-time Audio Processing")
833
+ gr.Markdown("Enable real-time audio processing to filter profanity as you speak.")
834
+
835
+ with gr.Row():
836
+ with gr.Column(scale=1):
837
+ start_stream_button = gr.Button("Start Real-time Processing", variant="primary")
838
+ stop_stream_button = gr.Button("Stop Real-time Processing", variant="secondary")
839
+ stream_status = gr.Textbox(label="Streaming Status", value="Inactive", interactive=False)
840
+
841
+ # Add microphone input specifically for streaming
842
+ stream_audio_input = gr.Audio(
843
+ label="Streaming Microphone Input",
844
+ type="filepath",
845
+ sources=["microphone"],
846
+ streaming=True
847
+ #waveform_options=gr.WaveformOptions(waveform_color="#4a90e2")
848
+ )
849
+
850
+ with gr.Column(scale=2):
851
+ # Add elements to display streaming results
852
+ stream_transcript = gr.Textbox(label="Live Transcription", lines=2)
853
+ stream_profanity_info = gr.Textbox(label="Profanity Detection", lines=2)
854
+ stream_clean_text = gr.Textbox(label="Clean Text", lines=2)
855
+ # Element to play the clean audio
856
+ stream_audio_output = gr.Audio(label="Clean Audio Output", type="filepath")
857
+
858
+ gr.Markdown("""
859
+ ### How Real-time Streaming Works
860
+ 1. Click "Start Real-time Processing" to begin
861
+ 2. Use the microphone input to speak
862
+ 3. The system will process audio in real-time, detect and clean profanity
863
+ 4. You'll see the transcription, profanity info, and clean output appear above
864
+ 5. Click "Stop Real-time Processing" when finished
865
+
866
+ Note: This feature requires microphone access and may have some latency.
867
+ """)
868
+
869
+ # Event handlers
870
def update_model_status(status_text):
    """Translate a status message into the pair (message, HTML badge).

    The badge color reflects the outcome reported in *status_text*:
    green when the message mentions success, red when it mentions an
    error, and orange otherwise (treated as "still loading").

    Args:
        status_text: Free-form status string produced by model loading.

    Returns:
        Tuple of (unchanged status text, HTML snippet for the badge).
    """
    message = status_text.lower()
    if "successfully" in message:
        color, label = "#2ecc71", "Loaded ✓"
    elif "error" in message:
        color, label = "#e74c3c", "Error ✗"
    else:
        color, label = "#f39c12", "Loading..."
    status_html = f"""<div style="text-align: center; padding: 5px;">
                <p><b>Model Status:</b> <span style="color: {color};">{label}</span></p>
                </div>"""
    return status_text, status_html
885
+
886
+ init_button.click(
887
+ lambda: update_model_status("Loading models, please wait..."),
888
+ inputs=[],
889
+ outputs=[init_output, model_status]
890
+ ).then(
891
+ load_models,
892
+ inputs=[],
893
+ outputs=[init_output]
894
+ ).then(
895
+ update_model_status,
896
+ inputs=[init_output],
897
+ outputs=[init_output, model_status]
898
+ )
899
+
900
+ text_button.click(
901
+ text_analysis,
902
+ inputs=[text_input, sensitivity],
903
+ outputs=[text_output, highlighted_output, text_audio_output]
904
+ )
905
+
906
+ clear_button.click(
907
+ lambda: [None, None, None],
908
+ inputs=None,
909
+ outputs=[text_input, highlighted_output, text_audio_output]
910
+ )
911
+
912
+ audio_button.click(
913
+ analyze_audio,
914
+ inputs=[audio_input, sensitivity],
915
+ outputs=[audio_output, audio_highlighted_output, clean_audio_output]
916
+ )
917
+
918
+ clear_audio_button.click(
919
+ lambda: [None, None, None, None],
920
+ inputs=None,
921
+ outputs=[audio_input, audio_output, audio_highlighted_output, clean_audio_output]
922
+ )
923
+
924
+ start_stream_button.click(
925
+ start_streaming,
926
+ inputs=[],
927
+ outputs=[stream_status]
928
+ )
929
+
930
+ stop_stream_button.click(
931
+ stop_streaming,
932
+ inputs=[],
933
+ outputs=[stream_status]
934
+ )
935
+
936
+ # Connect the streaming audio input to our processing function
937
+ # First function to debug the audio chunk format
938
def debug_audio_format(audio_chunk):
    """Log the runtime structure of an audio chunk, then pass it through.

    Used as a stream preprocessor purely for diagnostics: records the
    chunk's type and, when it is a tuple, its length and the type of
    each element. The chunk itself is returned unmodified.
    """
    details = [f"Type: {type(audio_chunk)}"]
    if isinstance(audio_chunk, tuple):
        details.append(f"Length: {len(audio_chunk)}")
        details.extend(
            f"Item {i} type: {type(item)}" for i, item in enumerate(audio_chunk)
        )
    format_info = ", ".join(details)
    logger.info(f"Audio chunk format: {format_info}")
    return audio_chunk
947
+
948
+ # Use the stream method with preprocessor for debugging
949
+ stream_audio_input.stream(
950
+ fn=process_stream_chunk,
951
+ inputs=[stream_audio_input],
952
+ outputs=[stream_transcript, stream_profanity_info, stream_clean_text, stream_audio_output],
953
+ preprocess=debug_audio_format
954
+ )
955
+
956
+ return ui
957
+
958
if __name__ == "__main__":
    # Permit duplicate OpenMP runtimes to coexist; avoids the libiomp
    # initialization conflict seen when several native libs bundle OpenMP.
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

    # Build the Gradio interface and serve it on all interfaces with a
    # public share link.
    app = create_ui()
    app.launch(server_name="0.0.0.0", share=True)