Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
print("Gradio version:", gr.__version__)
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
import os
|
| 5 |
+
import whisper
|
| 6 |
+
from gtts import gTTS
|
| 7 |
+
import time
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
# --- Model / client setup -------------------------------------------------
# HF access token is read from the environment (Space secret "tomoniaccess").
HF_TOKEN = os.getenv("tomoniaccess")

# Larger instruct model for better conversational quality.
client = InferenceClient(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=HF_TOKEN
)

# Whisper "base" model used for German speech-to-text.
whisper_model = whisper.load_model("base")

# Rolling chat history: list of {"role": ..., "content": ...} dicts,
# appended in user/assistant pairs for conversational continuity.
conversation_history = []

# Available characters with their personality traits.
personas = {
    "Jonas": {
        "age": 15,
        "traits": "schüchtern, einsam, selbstzweifelnd",
        "speech_style": "leise, zögerlich, kurze Pausen",
        "typical_responses": ["mmh...", "weiß nicht", "vielleicht", "ist halt so"]
    },
    "Lena": {
        "age": 16,
        "traits": "wütend, fühlt sich unverstanden, überfordert",
        "speech_style": "manchmal sarkastisch, abwehrend",
        "typical_responses": ["whatever", "ist mir egal", "verstehst du eh nicht"]
    }
}

# Currently selected persona.
current_persona = "Jonas"
persona = personas[current_persona]


def build_system_message(persona_name):
    """Render the system prompt for *persona_name*.

    Factored into one function so the prompt template lives in exactly one
    place (it was previously duplicated verbatim where the persona changes).

    Parameters
    ----------
    persona_name : str
        A key of ``personas``.

    Returns
    -------
    str
        The full German system prompt for the chat model.
    """
    p = personas[persona_name]
    return f"""Du bist {persona_name}, ein {p['age']}-jähriger Jugendlicher mit Depression.

CHARAKTEREIGENSCHAFTEN:
- {p['traits']}
- Sprichst {p['speech_style']}
- Nutzt oft: {', '.join(p['typical_responses'])}

VERHALTEN:
- Antworte kurz (1-3 Sätze max)
- Zeige Emotionen subtil, nicht direkt
- Manchmal lange Pausen (verwende "..." oder "äh...")
- Sei nicht zu gesprächig
- Reagiere natürlich auf das Gesagte

WICHTIG: Du bist NICHT hilfreich oder lösungsorientiert. Du bist ein Teenager mit echten Problemen."""


# System prompt for the active persona (same text as before, now built
# through the shared helper).
SYSTEM_MESSAGE = build_system_message(current_persona)
def reset_conversation():
    """Wipe the dialogue history so the next exchange starts fresh.

    Returns a short German status string for display in the UI.
    """
    global conversation_history
    conversation_history = []
    return "Gespräch zurückgesetzt."
def full_pipeline(audio_path, max_tokens, temperature, top_p):
    """Run one voice-chat turn: transcribe -> LLM reply -> TTS.

    Parameters
    ----------
    audio_path : str
        Path to the recorded user audio (Gradio ``type="filepath"`` input).
    max_tokens : int
        Requested generation length; capped at 80 below to force short replies.
    temperature, top_p : float
        Sampling parameters forwarded to the chat model.

    Returns
    -------
    tuple
        ``(user_input, response_text, audio_output_path, chat_display)`` where
        ``audio_output_path`` is an mp3 path or ``None`` if TTS failed.
    """
    global conversation_history
    t0 = time.time()

    # 1. Speech-to-text (German) via Whisper.
    t1 = time.time()
    result = whisper_model.transcribe(audio_path, language="de")
    user_input = result["text"].strip()
    t2 = time.time()
    print(f"⏱️ Transcription took {t2 - t1:.2f} sec")

    # 2. Assemble the prompt: system message plus at most the last 6 history
    #    entries (= 3 user/assistant turns) for context. The original had a
    #    separate branch for an empty history; extending an empty slice is
    #    equivalent, so one code path suffices.
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
    messages.extend(conversation_history[-6:])
    messages.append({"role": "user", "content": user_input})

    # 3. Streamed chat completion.
    response_text = ""
    t3 = time.time()

    try:
        for message in client.chat_completion(
            messages=messages,
            max_tokens=min(max_tokens, 80),  # force short replies
            stream=True,
            temperature=temperature,
            top_p=top_p,
            stop=["User:", "Human:", "\n\n"]  # stop words for natural boundaries
        ):
            token = message.choices[0].delta.content
            if token:
                response_text += token
    except Exception as e:
        # Stay in character even when the inference call fails.
        print(f"❌ Fehler bei Chat Completion: {e}")
        response_text = f"{random.choice(persona['typical_responses'])}... hab grad keine Lust zu reden."

    # Post-process: trim, and cut over-long replies after the first sentence.
    response_text = response_text.strip()
    if len(response_text.split()) > 25:
        sentences = response_text.split('.')
        response_text = sentences[0] + "..."

    t4 = time.time()
    print(f"🤖 Chat response took {t4 - t3:.2f} sec")

    # 4. Persist the turn in the rolling history (always as a pair).
    conversation_history.append({"role": "user", "content": user_input})
    conversation_history.append({"role": "assistant", "content": response_text})

    # 5. Text-to-speech, deliberately slow for a hesitant delivery.
    try:
        tts = gTTS(response_text, lang="de", slow=True)
        audio_output_path = f"response_{int(time.time())}.mp3"
        tts.save(audio_output_path)
        t5 = time.time()
        print(f"🔊 TTS took {t5 - t4:.2f} sec")
    except Exception as e:
        print(f"❌ TTS Fehler: {e}")
        audio_output_path = None
        t5 = time.time()

    print(f"✅ Total processing time: {t5 - t0:.2f} sec")
    print(f"💬 {current_persona}: {response_text}")

    # 6. Transcript of the last two turns for the UI.
    # BUGFIX: the original tested `len(conversation_history) > abs(i)` (and
    # `> abs(i - 1)` for the assistant line), which skipped the very first
    # turn (len == 2 is not > 2). Because entries are appended strictly in
    # user/assistant pairs, `>= abs(i)` guarantees both indices are valid.
    chat_display = ""
    for i in range(-4, 0, 2):
        if len(conversation_history) >= abs(i):
            chat_display += f"Du: {conversation_history[i]['content']}\n"
            chat_display += f"{current_persona}: {conversation_history[i + 1]['content']}\n\n"

    return user_input, response_text, audio_output_path, chat_display
def change_persona(new_persona):
    """Switch the active character and start a fresh conversation.

    Rebinds the module-level persona state, regenerates the system prompt
    for the chosen character, clears the history, and returns a German
    status string for the UI.
    """
    global current_persona, persona, SYSTEM_MESSAGE

    current_persona = new_persona
    persona = personas[new_persona]

    # Regenerate the system prompt for the newly selected character.
    SYSTEM_MESSAGE = f"""Du bist {current_persona}, ein {persona['age']}-jähriger Jugendlicher mit Depression.

CHARAKTEREIGENSCHAFTEN:
- {persona['traits']}
- Sprichst {persona['speech_style']}
- Nutzt oft: {', '.join(persona['typical_responses'])}

VERHALTEN:
- Antworte kurz (1-3 Sätze max)
- Zeige Emotionen subtil, nicht direkt
- Manchmal lange Pausen (verwende "..." oder "äh...")
- Sei nicht zu gesprächig
- Reagiere natürlich auf das Gesagte

WICHTIG: Du bist NICHT hilfreich oder lösungsorientiert. Du bist ein Teenager mit echten Problemen."""

    reset_conversation()
    return f"Persona gewechselt zu {current_persona}. Gespräch zurückgesetzt."
# --- Gradio UI ------------------------------------------------------------
# FIX: the original created components inside the event wiring
# (gr.Slider(0.7, 1.0, value=0.9) as a hidden top_p input, and a fresh
# gr.Textbox(label="Status") per handler). Components instantiated inside a
# Blocks context are implicitly rendered where they are created, so those
# appeared detached at the bottom of the layout and the status box was
# duplicated. They are now declared once in their proper place.
with gr.Blocks(title="Depression Training Chatbot") as demo:
    gr.Markdown("# 🧠 Depression Training Chatbot")
    gr.Markdown("**Zum Üben empathischer Gespräche mit depressiven Jugendlichen**")

    with gr.Row():
        with gr.Column(scale=2):
            # Character selection
            persona_dropdown = gr.Dropdown(
                choices=list(personas.keys()),
                value=current_persona,
                label="Charakter wählen"
            )
            persona_button = gr.Button("Charakter wechseln")

            # Audio input
            audio_input = gr.Audio(label="🎤 Sprich hier", type="filepath")

            # Sampling parameters — top_p is now a visible control.
            with gr.Row():
                max_tokens = gr.Slider(30, 150, value=60, step=10, label="Max Tokens")
                temperature = gr.Slider(0.3, 1.2, value=0.8, step=0.1, label="Kreativität")
                top_p = gr.Slider(0.7, 1.0, value=0.9, step=0.05, label="Top-p")

            submit_btn = gr.Button("💬 Senden", variant="primary")
            reset_btn = gr.Button("🔄 Gespräch zurücksetzen")

        with gr.Column(scale=3):
            # Outputs
            chat_history = gr.Textbox(label="💭 Gesprächsverlauf", lines=8, max_lines=12)
            user_text = gr.Textbox(label="Deine Nachricht", lines=2)
            bot_response = gr.Textbox(label=f"{current_persona}'s Antwort", lines=3)
            audio_output = gr.Audio(label="🔊 Audio-Antwort", type="filepath")
            # Single shared status line for persona switches and resets.
            status_box = gr.Textbox(label="Status")

    # Event handlers
    submit_btn.click(
        fn=full_pipeline,
        inputs=[audio_input, max_tokens, temperature, top_p],
        outputs=[user_text, bot_response, audio_output, chat_history]
    )

    persona_button.click(
        fn=change_persona,  # direct reference; the lambda wrapper was redundant
        inputs=[persona_dropdown],
        outputs=[status_box]
    )

    reset_btn.click(
        fn=reset_conversation,
        outputs=[status_box]
    )

if __name__ == "__main__":
    demo.launch(share=False, debug=True)