File size: 2,736 Bytes
85ce2c2
 
 
 
 
f00dfca
85ce2c2
f00dfca
85ce2c2
 
 
 
 
 
 
 
 
f00dfca
2dc0441
85ce2c2
f00dfca
 
85ce2c2
f00dfca
 
 
85ce2c2
f00dfca
 
 
85ce2c2
f00dfca
 
 
85ce2c2
f00dfca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85ce2c2
2dc0441
 
f00dfca
 
 
 
2dc0441
f00dfca
85ce2c2
f00dfca
 
 
 
 
85ce2c2
f00dfca
 
85ce2c2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import json
from transformers import pipeline
from gtts import gTTS
import os
import re

# Load personas
def load_personas(directory="personas", count=6):
    """Load persona definitions from ``{directory}/persona_{i}.json``.

    Files are numbered 1 through ``count`` inclusive and returned in that
    order, so ``personas[0]`` comes from ``persona_1.json``.

    Args:
        directory: Folder containing the persona JSON files.
        count: How many numbered persona files to read.

    Returns:
        A list with the decoded JSON content of each file, in file order.

    Raises:
        FileNotFoundError: If one of the expected persona files is missing.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    personas = []
    for i in range(1, count + 1):
        path = os.path.join(directory, f"persona_{i}.json")
        # Explicit UTF-8 so non-ASCII persona text decodes identically
        # regardless of the platform's default locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            personas.append(json.load(f))
    return personas

personas = load_personas()

# Small distilgpt2 checkpoint, picked so each debate round generates quickly.
bot = pipeline(task="text-generation", model="distilgpt2")

# Words masked out of every model reply by clean_response().
BAD_WORDS = [
    "fuck",
    "shit",
    "bitch",
    "asshole",
]

def clean_response(text, words=None):
    """Replace every occurrence of a banned word in ``text`` with "[censored]".

    Matching is case-insensitive and substring-based, so inflected forms are
    masked as well (e.g. "fucking" becomes "[censored]ing").

    Args:
        text: The string to filter.
        words: Iterable of banned words. Defaults to the module-level
            ``BAD_WORDS`` list. Empty strings are ignored.

    Returns:
        ``text`` with each banned-word match replaced. If there are no
        (non-empty) banned words, ``text`` is returned unchanged.
    """
    if words is None:
        words = BAD_WORDS
    # Longest words first: regex alternation takes the first alternative that
    # matches, so a short word must not pre-empt a longer one containing it.
    banned = sorted((w for w in words if w), key=len, reverse=True)
    if not banned:
        # An empty pattern matches at every position and would splice
        # "[censored]" between every character, so return the text as-is.
        return text
    # Escape each word so regex metacharacters are matched literally.
    pattern = re.compile("|".join(map(re.escape, banned)), re.IGNORECASE)
    return pattern.sub("[censored]", text)

# Debate state read and rewritten by next_round() and start_debate():
# the running transcript fed back into the model, and the index of the
# persona whose turn comes next.
# NOTE(review): as module globals these are shared by every connected user;
# per-session state (e.g. gr.State) would isolate concurrent debates.
history, round_number = "", 0

# Debate function
def next_round(topic, consensus_level):
    """Generate the next persona's line in the debate and speak it.

    Reads and advances the module-level ``round_number`` and ``history``.
    The persona at index ``round_number`` replies to the transcript so far.
    Its filtered reply is appended to ``history`` and ``round_number`` is
    incremented.

    Args:
        topic: The debate question from the UI. Not read here: the topic
            reaches the model through ``history``, which ``start_debate``
            seeds with it.
        consensus_level: Slider value. Above 50 asks for agreement;
            otherwise asks for a critical reply.

    Returns:
        A ``(text, audio_path, next_button)`` tuple for the Gradio outputs.
        ``audio_path`` is ``None`` when the debate is already over or the
        reply contains nothing speakable.
    """
    global round_number, history

    total_rounds = len(personas)
    if round_number >= total_rounds:
        return "Debate complete.", None, gr.Button(visible=False)

    persona = personas[round_number]

    # Adjust prompt based on consensus slider
    if consensus_level > 50:
        style = "Aim to find agreement and common ground."
    else:
        style = "Challenge previous points and offer a critical view."

    prompt = f"{persona['prompt_style']} {style}\n{history}\n{persona['name']}:"
    # max_new_tokens bounds only the continuation. The former max_length=50
    # also counted the prompt, so once the history grew past ~50 tokens the
    # model had no room left to generate anything. return_full_text=False
    # returns just the continuation instead of relying on slicing off
    # len(prompt) characters of the decoded text.
    generated = bot(prompt, max_new_tokens=50, do_sample=True, return_full_text=False)
    response_text = generated[0]["generated_text"].strip().split("\n")[0]

    # Filter bad words
    response_text = clean_response(response_text)

    history += f"{persona['name']}: {response_text}\n"

    # gTTS raises on text with nothing speakable (empty or punctuation-only),
    # so only synthesize audio when the reply has at least one letter/digit.
    audio_file = None
    if any(ch.isalnum() for ch in response_text):
        tts = gTTS(response_text, lang='en')
        audio_file = f"temp_{round_number}.mp3"
        tts.save(audio_file)

    round_number += 1

    # Keep the "next round" button only while personas remain.
    next_button = gr.Button(visible=round_number < total_rounds)

    return f"Round {round_number}: {persona['name']} says:\n" + response_text, audio_file, next_button

with gr.Blocks() as iface:
    gr.Markdown("# AI Persona Debate")

    # Inputs: the question to debate and how conciliatory the personas should be.
    topic = gr.Textbox(label="Enter Debate Question")
    consensus = gr.Slider(minimum=0, maximum=100, value=50, label="Disagreement <-> Consensus")
    start_btn = gr.Button("Start Debate")

    # Outputs: the latest persona's line, its spoken version, and the button
    # that advances the debate (hidden until a round has run).
    output_text = gr.Textbox(label="Debate Response")
    output_audio = gr.Audio(label="Voice", autoplay=True)
    next_btn = gr.Button("Trigger Next Round", visible=False)

    def start_debate(user_topic, consensus_level):
        """Reset the shared debate state for ``user_topic`` and run round one."""
        global history, round_number
        round_number = 0
        history = f"Debate Topic: {user_topic}\n"
        return next_round(user_topic, consensus_level)

    # Both buttons read the same inputs and refresh the same three outputs.
    round_inputs = [topic, consensus]
    round_outputs = [output_text, output_audio, next_btn]
    start_btn.click(fn=start_debate, inputs=round_inputs, outputs=round_outputs)
    next_btn.click(fn=next_round, inputs=round_inputs, outputs=round_outputs)

iface.launch()