Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,8 +3,9 @@ import json
|
|
| 3 |
from transformers import pipeline
|
| 4 |
from gtts import gTTS
|
| 5 |
import os
|
|
|
|
| 6 |
|
| 7 |
-
# Load
|
| 8 |
def load_personas():
|
| 9 |
personas = []
|
| 10 |
for i in range(1, 7):
|
|
@@ -14,37 +15,74 @@ def load_personas():
|
|
| 14 |
|
| 15 |
personas = load_personas()
|
| 16 |
|
| 17 |
-
# Use distilgpt2 for
|
| 18 |
bot = pipeline("text-generation", model="distilgpt2")
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
history = f"Debate Topic: {topic}\n"
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
response_text = response[len(prompt):].strip().split("\n")[0]
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
with gr.Blocks() as iface:
|
| 38 |
gr.Markdown("# AI Persona Debate")
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
output_audio = gr.Audio(label="Voice", autoplay=True)
|
|
|
|
| 43 |
|
| 44 |
-
def
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
iface.launch()
|
|
|
|
| 3 |
from transformers import pipeline
|
| 4 |
from gtts import gTTS
|
| 5 |
import os
|
| 6 |
+
import re
|
| 7 |
|
| 8 |
+
# Load personas
|
| 9 |
def load_personas():
|
| 10 |
personas = []
|
| 11 |
for i in range(1, 7):
|
|
|
|
| 15 |
|
| 16 |
personas = load_personas()
|
| 17 |
|
| 18 |
+
# Use distilgpt2 for speed
|
| 19 |
bot = pipeline("text-generation", model="distilgpt2")
|
| 20 |
|
| 21 |
+
# Swear filter list
|
| 22 |
+
BAD_WORDS = ["fuck", "shit", "bitch", "asshole"]
|
|
|
|
| 23 |
|
| 24 |
+
def clean_response(text):
|
| 25 |
+
pattern = re.compile(r"|".join(BAD_WORDS), re.IGNORECASE)
|
| 26 |
+
return pattern.sub("[censored]", text)
|
|
|
|
| 27 |
|
| 28 |
+
# Global state
|
| 29 |
+
history = ""
|
| 30 |
+
round_number = 0
|
| 31 |
|
| 32 |
+
# Debate function
|
| 33 |
+
def next_round(topic, consensus_level):
|
| 34 |
+
global round_number, history
|
| 35 |
|
| 36 |
+
if round_number >= 6:
|
| 37 |
+
return "Debate complete.", None, gr.Button(visible=False)
|
| 38 |
+
|
| 39 |
+
persona = personas[round_number]
|
| 40 |
+
|
| 41 |
+
# Adjust prompt based on consensus slider
|
| 42 |
+
if consensus_level > 50:
|
| 43 |
+
style = "Aim to find agreement and common ground."
|
| 44 |
+
else:
|
| 45 |
+
style = "Challenge previous points and offer a critical view."
|
| 46 |
+
|
| 47 |
+
prompt = f"{persona['prompt_style']} {style}\n{history}\n{persona['name']}:"
|
| 48 |
+
response = bot(prompt, max_length=50, do_sample=True)[0]['generated_text']
|
| 49 |
+
response_text = response[len(prompt):].strip().split("\n")[0]
|
| 50 |
+
|
| 51 |
+
# Filter bad words
|
| 52 |
+
response_text = clean_response(response_text)
|
| 53 |
+
|
| 54 |
+
history += f"{persona['name']}: {response_text}\n"
|
| 55 |
+
|
| 56 |
+
# Generate TTS
|
| 57 |
+
tts = gTTS(response_text, lang='en')
|
| 58 |
+
audio_file = f"temp_{round_number}.mp3"
|
| 59 |
+
tts.save(audio_file)
|
| 60 |
+
|
| 61 |
+
round_number += 1
|
| 62 |
+
|
| 63 |
+
if round_number >= 6:
|
| 64 |
+
button_visibility = gr.Button(visible=False)
|
| 65 |
+
else:
|
| 66 |
+
button_visibility = gr.Button(visible=True)
|
| 67 |
+
|
| 68 |
+
return f"Round {round_number}: {persona['name']} says:\n" + response_text, audio_file, button_visibility
|
| 69 |
|
| 70 |
with gr.Blocks() as iface:
|
| 71 |
gr.Markdown("# AI Persona Debate")
|
| 72 |
+
topic = gr.Textbox(label="Enter Debate Question")
|
| 73 |
+
consensus = gr.Slider(0, 100, value=50, label="Disagreement <-> Consensus")
|
| 74 |
+
start_btn = gr.Button("Start Debate")
|
| 75 |
+
output_text = gr.Textbox(label="Debate Response")
|
| 76 |
output_audio = gr.Audio(label="Voice", autoplay=True)
|
| 77 |
+
next_btn = gr.Button("Trigger Next Round", visible=False)
|
| 78 |
|
| 79 |
+
def start_debate(user_topic, consensus_level):
|
| 80 |
+
global history, round_number
|
| 81 |
+
history = f"Debate Topic: {user_topic}\n"
|
| 82 |
+
round_number = 0
|
| 83 |
+
return next_round(user_topic, consensus_level)
|
| 84 |
|
| 85 |
+
start_btn.click(start_debate, inputs=[topic, consensus], outputs=[output_text, output_audio, next_btn])
|
| 86 |
+
next_btn.click(next_round, inputs=[topic, consensus], outputs=[output_text, output_audio, next_btn])
|
| 87 |
|
| 88 |
iface.launch()
|