Spaces:
Sleeping
Sleeping
File size: 17,185 Bytes
163b430 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 |
import os

# Cache / tokenizer settings have to be in the environment BEFORE gradio,
# transformers (and the huggingface_hub package they pull in) are imported:
# those libraries read the cache locations when they are first imported, so
# assigning the variables after the imports has no effect.
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"  # Use tmp for HF Spaces
os.environ["HF_HOME"] = "/tmp/huggingface"  # Cache location

import gc
import tempfile
import time

import google.generativeai as genai
import gradio as gr
import librosa
import psutil
import soundfile as sf
import torch
from dotenv import load_dotenv
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, VitsModel, VitsTokenizer

# Try to load .env file as fallback (for local development)
# HF Spaces will use secrets directly, so this won't override them
load_dotenv()
def get_memory_usage():
    """Return the resident set size (RSS) of the current process, in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024
def log_memory(context=""):
    """Print the current process memory usage, labelled with *context*.

    Args:
        context: Free-form label inserted into the log line so successive
            measurements can be told apart.
    """
    usage = get_memory_usage()
    line = f"Memory usage {context}: {usage:.1f} MB"
    print(line)
class LatinConversationBot:
    """Speech-in / speech-out Latin chat bot.

    One instance drives three components:
      * a speech-recognition (ASR) checkpoint loaded through
        ``AutoModelForSpeechSeq2Seq``,
      * a Gemini model used for replies, grammar fixes and translations,
      * a VITS checkpoint used for speech synthesis (TTS).

    ``message_audio`` and ``message_texts`` map a message id to the
    synthesized WAV path / reply text; callers fill them so that earlier
    replies can be replayed.
    """

    ASR_MODEL_ID = "ken-z/latin_whisper-small"
    TTS_MODEL_ID = "Ken-Z/latin_SpeechT5"
    CACHE_DIR = "/tmp/transformers_cache"
    # Used only when the TTS model config does not expose ``sampling_rate``.
    DEFAULT_TTS_SAMPLE_RATE = 16000

    def __init__(self):
        """Configure Gemini and try to pre-load both speech models.

        Raises:
            ValueError: if ``GEMINI_API_KEY`` is not set in the environment.
        """
        log_memory("at initialization start")
        # Force CPU-only to reduce memory usage on Hugging Face Spaces
        self.device = "cpu"
        self.message_audio = {}
        self.message_texts = {}
        # Initialize Gemini using HF Spaces secret or .env fallback
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            # More helpful error message for both HF Spaces and local dev
            raise ValueError(
                "GEMINI_API_KEY not found!\n"
                "For Hugging Face Spaces:\n"
                " 1. Go to your Space settings\n"
                " 2. Click on 'Repository secrets'\n"
                " 3. Add 'GEMINI_API_KEY' with your API key\n"
                "For Local Development:\n"
                " 1. Create a .env file in the project root\n"
                " 2. Add: GEMINI_API_KEY=your_api_key_here"
            )
        genai.configure(api_key=api_key)
        self.gemini_model = genai.GenerativeModel('gemini-flash-latest')
        # Model containers
        self.asr_processor = None
        self.asr_model = None
        self.tts_model = None
        self.tts_tokenizer = None
        self.models_loaded = {"asr": False, "tts": False}
        print(f"Bot initialized on device: {self.device}")
        # Pre-load models at startup for faster response. A failure here is
        # not fatal: the _ensure_*_loaded helpers retry on first use.
        try:
            print("π Starting model pre-loading...")
            self._preload_models()
            print("β All models loaded successfully!")
        except Exception as e:
            print(f"β οΈ Model pre-loading failed: {e}")
            print("Models will be loaded on-demand")
        log_memory("after initialization")

    def _load_asr(self):
        """Load the ASR processor and model and mark ASR as available."""
        self.asr_processor = AutoProcessor.from_pretrained(
            self.ASR_MODEL_ID,
            cache_dir=self.CACHE_DIR,
            local_files_only=False,
        )
        self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.ASR_MODEL_ID,
            torch_dtype=torch.float32,
            cache_dir=self.CACHE_DIR,
            low_cpu_mem_usage=True,  # Optimize memory usage
            local_files_only=False,
        ).to(self.device)
        # Only flagged once both pieces are in place.
        self.models_loaded["asr"] = True

    def _load_tts(self):
        """Load the TTS tokenizer and model and mark TTS as available."""
        self.tts_tokenizer = VitsTokenizer.from_pretrained(
            self.TTS_MODEL_ID,
            cache_dir=self.CACHE_DIR,
            local_files_only=False,
        )
        self.tts_model = VitsModel.from_pretrained(
            self.TTS_MODEL_ID,
            torch_dtype=torch.float32,
            cache_dir=self.CACHE_DIR,
            low_cpu_mem_usage=True,  # Optimize memory usage
            local_files_only=False,
        ).to(self.device)
        self.models_loaded["tts"] = True

    def _preload_models(self):
        """Pre-load ASR, then TTS, logging memory after each.

        Raises:
            Exception: whatever the underlying loader raised, re-raised
                unchanged after logging. ``models_loaded`` keeps reflecting
                what actually loaded, so a model that succeeded is not
                reloaded later.
        """
        try:
            print("π₯ Loading ASR models...")
            self._load_asr()
            log_memory("after ASR loading")
            print("π΅ Loading TTS models...")
            self._load_tts()
            log_memory("after TTS loading")
        except Exception as e:
            print(f"Error in model loading: {e}")
            # Bare raise keeps the original traceback.
            raise

    def _ensure_asr_loaded(self):
        """Ensure ASR models are loaded, loading them on demand if needed."""
        if not self.models_loaded["asr"]:
            print("Loading ASR models on-demand...")
            self._load_asr()

    def _ensure_tts_loaded(self):
        """Ensure TTS models are loaded, loading them on demand if needed."""
        if not self.models_loaded["tts"]:
            print("Loading TTS models on-demand...")
            self._load_tts()

    def _cleanup_models(self):
        """Drop all model references and run the garbage collector.

        After this call both ``models_loaded`` flags are False, so the next
        transcription / synthesis reloads the models on demand.
        """
        log_memory("before cleanup")
        self.asr_model = None
        self.asr_processor = None
        self.models_loaded["asr"] = False
        self.tts_model = None
        self.tts_tokenizer = None
        self.models_loaded["tts"] = False
        gc.collect()
        log_memory("after cleanup")
        print("Models cleaned up from memory")

    def transcribe_audio(self, audio_path):
        """Transcribe the audio file at *audio_path*.

        Args:
            audio_path: Path of an audio file readable by ``librosa.load``.

        Returns:
            The stripped transcription, or a string of the form
            ``"Error: <message>"`` if anything failed.
        """
        try:
            # Ensure ASR models are loaded
            self._ensure_asr_loaded()
            # Resample to 16 kHz, the rate passed to the processor below.
            audio, _ = librosa.load(audio_path, sr=16000)
            input_features = self.asr_processor(
                audio, sampling_rate=16000, return_tensors="pt"
            ).input_features.to(self.device)
            with torch.no_grad():
                predicted_ids = self.asr_model.generate(input_features)
            result = self.asr_processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0].strip()
            # Clean up tensors but keep models loaded
            del input_features, predicted_ids
            gc.collect()
            return result
        except Exception as e:
            print(f"ASR Error: {str(e)}")
            return f"Error: {str(e)}"

    def _call_gemini(self, prompt):
        """Send *prompt* to Gemini and return the stripped reply text.

        Returns:
            The reply, or ``"Error: Gemini API not available"`` if the call
            (or reading its text) raised.
        """
        try:
            return self.gemini_model.generate_content(prompt).text.strip()
        except Exception as e:
            print(f"Gemini API error: {e}")
            return "Error: Gemini API not available"

    def generate_response(self, text):
        """Return a short Latin reply to the user's *text*."""
        prompt = f"""You are a Latin conversation bot. Respond ONLY in Latin, keep responses to 1-2 sentences, use proper Classical Latin grammar with proper diacritics, and be conversational.
Examples: "Salve" β "Salve! Quid agis hodie?", "Hello" β "Salve! Latine loquere, quaeso!"
User: {text}
Response:"""
        return self._call_gemini(prompt)

    def improve_latin_grammar(self, text):
        """Ask Gemini to correct *text* and parse its tagged reply.

        Returns:
            A dict with keys ``"corrected"`` (falls back to *text* when no
            ``CORRECTED:`` line is found) and ``"explanation"`` (falls back
            to ``"No explanation provided."``).
        """
        prompt = f"""Fix Latin grammar, diacritics, and word order. Format:
CORRECTED: [corrected text]
EXPLANATION: [brief explanation of fixes only]
Text: {text}"""
        response = self._call_gemini(prompt)
        # Parse response
        corrected_tag = "CORRECTED:"
        explanation_tag = "EXPLANATION:"
        corrected = explanation = ""
        for raw_line in response.splitlines():
            # Tolerate indentation / trailing carriage returns in the reply.
            line = raw_line.strip()
            if line.startswith(corrected_tag):
                corrected = line[len(corrected_tag):].strip()
            elif line.startswith(explanation_tag):
                explanation = line[len(explanation_tag):].strip()
        return {
            "corrected": corrected or text,
            "explanation": explanation or "No explanation provided."
        }

    def translate_latin(self, text, target_language):
        """Translate Latin *text* into *target_language* via Gemini."""
        prompt = f"""Translate this Latin text to {target_language}. Return ONLY the translation, no explanations.
Latin text: {text}
{target_language} translation:"""
        return self._call_gemini(prompt)

    def synthesize_speech(self, text):
        """Synthesize *text* to a temporary WAV file.

        Returns:
            Path of the WAV file (created with ``delete=False``, so the
            caller owns its removal), or ``None`` if synthesis failed.
        """
        try:
            # Ensure TTS models are loaded
            self._ensure_tts_loaded()
            inputs = self.tts_tokenizer(text, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                speech = self.tts_model(**inputs).waveform.squeeze().cpu().numpy()
            # Clean up tensors but keep models loaded
            del inputs
            gc.collect()
            # Write with the rate the model declares rather than a fixed one.
            sample_rate = getattr(
                self.tts_model.config, "sampling_rate", self.DEFAULT_TTS_SAMPLE_RATE
            )
            # Reserve the name, close the handle, then write by path: writing
            # to a path that is still open elsewhere fails on some platforms.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                wav_path = tmp_file.name
            sf.write(wav_path, speech, samplerate=sample_rate)
            return wav_path
        except Exception as e:
            print(f"TTS error: {e}")
            return None
# NOTE: the shared `bot_instance` is created once, further down, right before
# the UI is built. The handler functions below only look the name up when they
# are called, so instantiating the bot here as well would just load both speech
# models a second time at startup.
def add_message(history, message):
    """Append the submitted audio transcriptions and text to the chat history.

    Args:
        history: List of ``{"role", "content"}`` message dicts; extended in
            place (a ``None`` history is treated as empty).
        message: Multimodal textbox value with ``"files"`` and ``"text"``.

    Returns:
        ``(history, textbox)`` where the textbox update clears the input and
        disables it until the bot has answered.
    """
    audio_suffixes = ('.wav', '.mp3', '.m4a', '.ogg', '.flac')
    if history is None:
        history = []
    for file_info in message["files"] or []:
        file_path = file_info.path if hasattr(file_info, 'path') else file_info
        # Compare case-insensitively so e.g. "clip.WAV" is still transcribed.
        if str(file_path).lower().endswith(audio_suffixes):
            transcription = bot_instance.transcribe_audio(file_path)
            history.append({"role": "user", "content": f"π€ {transcription}"})
    text = message["text"]
    if text and text.strip():
        history.append({"role": "user", "content": text})
    return history, gr.MultimodalTextbox(value=None, interactive=False)
def get_dropdown_choices(history):
    """Build the ``(label, value)`` choices for the three dropdowns.

    Returns:
        ``(replay_choices, improve_choices, translate_choices)``: replay
        entries come from the bot's stored reply texts keyed by message id;
        improve / translate entries are the user / assistant messages of
        *history*, keyed by their index in *history*.
    """
    def _clip(text, limit):
        # Same truncation rule for every label: cut at `limit`, add an ellipsis.
        return f"{text[:limit]}{'...' if len(text) > limit else ''}"

    replay_choices = []
    for msg_id, reply_text in bot_instance.message_texts.items():
        replay_choices.append((f"π {_clip(reply_text, 30)}", msg_id))

    improve_choices = []
    for position, entry in enumerate(history):
        if entry["role"] == "user":
            spoken = entry['content'].replace('π€ ', '')
            improve_choices.append((f"Message {position+1}: {_clip(spoken, 50)}", position))

    translate_choices = []
    for position, entry in enumerate(history):
        if entry["role"] == "assistant":
            translate_choices.append((f"Bot {position+1}: {_clip(entry['content'], 50)}", position))

    return replay_choices, improve_choices, translate_choices
def bot(history):
    """Answer the newest user message and synthesize the reply.

    Args:
        history: List of ``{"role", "content"}`` message dicts; the reply is
            appended in place.

    Returns:
        ``(history, audio_path_or_None, replay_dropdown, improve_dropdown,
        translate_dropdown)``.
    """
    if not history:
        return history, None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

    last_entry = history[-1]
    if last_entry["role"] != "user":
        # Nothing new was submitted (e.g. an empty message): do not let the
        # bot answer its own previous reply; just refresh the dropdowns.
        replay_choices, improve_choices, translate_choices = get_dropdown_choices(history)
        return history, None, gr.Dropdown(choices=replay_choices), gr.Dropdown(choices=improve_choices), gr.Dropdown(choices=translate_choices)

    last_message = last_entry["content"]
    mic_prefix = "π€ "
    # Strip only the leading marker, not occurrences inside the text.
    user_text = last_message[len(mic_prefix):] if last_message.startswith(mic_prefix) else last_message

    response_text = bot_instance.generate_response(user_text)
    message_id = f"msg_{len(history)}_{int(time.time())}"
    history.append({"role": "assistant", "content": response_text})

    audio_file = bot_instance.synthesize_speech(response_text)
    if audio_file:
        # Only replies that actually have audio are offered for replay.
        bot_instance.message_audio[message_id] = audio_file
        bot_instance.message_texts[message_id] = response_text

    replay_choices, improve_choices, translate_choices = get_dropdown_choices(history)
    return history, audio_file, gr.Dropdown(choices=replay_choices), gr.Dropdown(choices=improve_choices), gr.Dropdown(choices=translate_choices)
def improve_message_grammar(history, message_index):
    """Replace one user message with its grammar-corrected version.

    Args:
        history: List of ``{"role", "content"}`` message dicts; the selected
            entry is updated in place when a correction differs.
        message_index: Index into *history*; must point at a user message.

    Returns:
        ``(history, explanation)``. The explanation is ``""`` when the index
        is missing, out of range, or does not point at a user message.
    """
    mic_prefix = "π€ "
    improved_marker = " β¨"

    if (
        not history
        or message_index is None
        or not 0 <= message_index < len(history)
        or history[message_index]["role"] != "user"
    ):
        return history, ""

    original_text = history[message_index]["content"]
    prefix = mic_prefix if original_text.startswith(mic_prefix) else ""
    # Remove only the leading marker, then any marker left by an earlier
    # improvement, so neither is sent to the model or stacked up.
    text_to_improve = original_text[len(prefix):]
    while text_to_improve.endswith(improved_marker):
        text_to_improve = text_to_improve[:-len(improved_marker)]

    improvement_result = bot_instance.improve_latin_grammar(text_to_improve)
    corrected_text = improvement_result["corrected"]
    explanation = improvement_result["explanation"]

    if corrected_text and corrected_text != text_to_improve:
        history[message_index]["content"] = f"{prefix}{corrected_text}{improved_marker}"
    return history, explanation
def clear_all_data():
    """Reset the conversation: stored audio, stored texts and loaded models.

    Returns:
        Empty chat history, no audio, and three emptied dropdowns, in the
        order expected by the clear button's outputs.
    """
    # The synthesized WAVs are temp files created with delete=False, so they
    # stay on disk until removed explicitly.
    for audio_path in list(bot_instance.message_audio.values()):
        try:
            os.remove(audio_path)
        except FileNotFoundError:
            pass  # already gone; nothing to clean up
        except OSError as e:
            print(f"Could not remove audio file {audio_path}: {e}")
    bot_instance.message_audio.clear()
    bot_instance.message_texts.clear()
    # Also clean up models to free memory
    bot_instance._cleanup_models()
    print("All data and models cleared from memory")
    return [], None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
# Initialize the bot instance early
# Module-level singleton used by every handler in this file. Constructing it
# attempts to pre-load the speech models, so the announcement is printed first
# and this statement does not return until that attempt has finished.
print("π Initializing Latin Conversation Bot...")
bot_instance = LatinConversationBot()
# UI definition: chat window, multimodal input, and the replay / improve /
# translate controls, followed by the event wiring for all of them.
with gr.Blocks(title="ποΈ Latin Conversation Bot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ποΈ Latin Conversation Bot
Speak or type in Latin for AI-powered conversations with speech synthesis and grammar improvement!
""")
    chatbot = gr.Chatbot(type="messages", height=400, show_label=False)
    chat_input = gr.MultimodalTextbox(
        interactive=True, file_types=["audio"], placeholder="π€ Record or type in Latin...",
        show_label=False, sources=["microphone", "upload"]
    )
    with gr.Row():
        audio_output = gr.Audio(label="π Bot Response", autoplay=True, scale=2)
        replay_dropdown = gr.Dropdown(label="π Replay Message", choices=[], scale=1)
    with gr.Row():
        improve_dropdown = gr.Dropdown(label="β¨ Select Message to Improve", choices=[], scale=2)
        improve_btn = gr.Button("β¨ Improve Grammar", size="sm", variant="secondary", scale=1)
    grammar_explanation = gr.Textbox(label="π Grammar Explanation", interactive=False, visible=False)
    with gr.Row():
        translate_dropdown = gr.Dropdown(label="π Select Bot Message to Translate", choices=[], scale=2)
        language_dropdown = gr.Dropdown(
            label="Target Language",
            choices=["English", "Spanish", "French", "German", "Italian", "Portuguese", "Chinese", "Japanese"],
            value="English",
            scale=1
        )
        translate_btn = gr.Button("π Translate", size="sm", variant="secondary", scale=1)
    translation_output = gr.Textbox(label="π Translation", interactive=False, visible=False)
    clear_btn = gr.Button("ποΈ Clear", size="sm")

    # Event handlers
    # Submit chain: record the user's message, let the bot answer, then
    # re-enable the input box that add_message disabled.
    chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
    bot_msg = chat_msg.then(bot, chatbot, [chatbot, audio_output, replay_dropdown, improve_dropdown, translate_dropdown])
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
    replay_dropdown.change(
        lambda msg_id: bot_instance.message_audio.get(msg_id) if msg_id else None,
        inputs=[replay_dropdown], outputs=[audio_output]
    )
    clear_btn.click(clear_all_data, outputs=[chatbot, audio_output, replay_dropdown, improve_dropdown, translate_dropdown])

    def improve_selected_message(history, selected_index):
        """Improve the selected user message and show the explanation, if any.

        Returns:
            ``(history, improve_dropdown_update, explanation_textbox_update)``.
        """
        if selected_index is None:
            _, improve_choices, _ = get_dropdown_choices(history)
            return history, gr.Dropdown(choices=improve_choices), gr.Textbox(visible=False)
        improved_history, explanation = improve_message_grammar(history, selected_index)
        _, improve_choices, _ = get_dropdown_choices(improved_history)
        # `visible` must be a real bool. Hide the box for an empty explanation
        # and for the placeholder texts that carry no information.
        show_explanation = bool(explanation) and explanation not in (
            "No corrections needed.",
            "No explanation provided.",
        )
        return improved_history, gr.Dropdown(choices=improve_choices), gr.Textbox(value=explanation if show_explanation else "", visible=show_explanation)

    def translate_selected_message(history, selected_index, target_language):
        """Translate the selected assistant message into *target_language*.

        Returns:
            A textbox update: hidden when the selection is missing, out of
            range, or not an assistant message; otherwise visible and holding
            the original text plus its translation.
        """
        if (
            selected_index is None
            or not history
            or not 0 <= selected_index < len(history)
            or history[selected_index]["role"] != "assistant"
        ):
            return gr.Textbox(visible=False)
        latin_text = history[selected_index]["content"]
        translation = bot_instance.translate_latin(latin_text, target_language)
        return gr.Textbox(value=f"Original: {latin_text}\n\n{target_language}: {translation}", visible=True)

    improve_btn.click(improve_selected_message, [chatbot, improve_dropdown], [chatbot, improve_dropdown, grammar_explanation])
    translate_btn.click(translate_selected_message, [chatbot, translate_dropdown, language_dropdown], [translation_output])
# Script entry point: start the Gradio server only when run directly.
if __name__ == "__main__":
    # Launch with optimized settings for HF Spaces
    demo.launch(
        server_port=7860,  # Standard HF Spaces port
        share=False,
        show_error=True,
        quiet=False,  # Show startup logs
    )