import os

# Set cache-related environment variables before importing transformers so they take effect
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"  # Use tmp for HF Spaces
os.environ["HF_HOME"] = "/tmp/huggingface"  # Cache location

import gradio as gr
import time
import torch
import gc
import psutil
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, VitsModel, VitsTokenizer
import soundfile as sf
import librosa
import tempfile
import google.generativeai as genai
from dotenv import load_dotenv

# Try to load .env file as fallback (for local development)
# HF Spaces will use secrets directly, so this won't override them
load_dotenv()

def get_memory_usage():
    """Get current memory usage in MB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024

def log_memory(context=""):
    """Log current memory usage"""
    memory_mb = get_memory_usage()
    print(f"Memory usage {context}: {memory_mb:.1f} MB")

class LatinConversationBot:
    def __init__(self):
        log_memory("at initialization start")
        
        # Force CPU-only to reduce memory usage on Hugging Face Spaces
        self.device = "cpu"
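        # Per-message caches: message ID -> synthesized audio path / response text (used by the replay dropdown)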
        self.message_audio = {}
        self.message_texts = {}
        
        # Initialize Gemini using HF Spaces secret or .env fallback
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            # More helpful error message for both HF Spaces and local dev
            raise ValueError(
                "GEMINI_API_KEY not found!\n"
                "For Hugging Face Spaces:\n"
                "  1. Go to your Space settings\n"
                "  2. Click on 'Repository secrets'\n" 
                "  3. Add 'GEMINI_API_KEY' with your API key\n"
                "For Local Development:\n"
                "  1. Create a .env file in the project root\n"
                "  2. Add: GEMINI_API_KEY=your_api_key_here"
            )
        genai.configure(api_key=api_key)
        self.gemini_model = genai.GenerativeModel('gemini-flash-latest')
        
        # Model containers
        self.asr_processor = None
        self.asr_model = None
        self.tts_model = None
        self.tts_tokenizer = None
        self.models_loaded = {"asr": False, "tts": False}
        
        print(f"Bot initialized on device: {self.device}")
        
        # Pre-load models at startup for faster response
        try:
            print("πŸš€ Starting model pre-loading...")
            self._preload_models()
            print("βœ… All models loaded successfully!")
        except Exception as e:
            print(f"⚠️ Model pre-loading failed: {e}")
            print("Models will be loaded on-demand")
        
        log_memory("after initialization")
    
    def _preload_models(self):
        """Pre-load models at startup but manage memory efficiently"""
        try:
            # Load ASR first with optimizations
            print("πŸ“₯ Loading ASR models...")
            self.asr_processor = AutoProcessor.from_pretrained(
                "ken-z/latin_whisper-small",
                cache_dir="/tmp/transformers_cache",
                local_files_only=False
            )
            self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "ken-z/latin_whisper-small",
                torch_dtype=torch.float32,
                cache_dir="/tmp/transformers_cache",
                low_cpu_mem_usage=True,  # Optimize memory usage
                local_files_only=False
            ).to(self.device)
            self.models_loaded["asr"] = True
            log_memory("after ASR loading")
            
            # Load TTS with optimizations
            print("🎡 Loading TTS models...")
            self.tts_tokenizer = VitsTokenizer.from_pretrained(
                "Ken-Z/latin_SpeechT5",
                cache_dir="/tmp/transformers_cache",
                local_files_only=False
            )
            self.tts_model = VitsModel.from_pretrained(
                "Ken-Z/latin_SpeechT5",
                torch_dtype=torch.float32,
                cache_dir="/tmp/transformers_cache",
                low_cpu_mem_usage=True,  # Optimize memory usage
                local_files_only=False
            ).to(self.device)
            self.models_loaded["tts"] = True
            log_memory("after TTS loading")
            
        except Exception as e:
            print(f"Error in model loading: {e}")
            # Fallback to lazy loading
            self.models_loaded = {"asr": False, "tts": False}
            raise e
    
    def _ensure_asr_loaded(self):
        """Ensure ASR models are loaded"""
        if not self.models_loaded["asr"]:
            print("Loading ASR models on-demand...")
            self.asr_processor = AutoProcessor.from_pretrained("ken-z/latin_whisper-small")
            self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "ken-z/latin_whisper-small", 
                torch_dtype=torch.float32
            ).to(self.device)
            self.models_loaded["asr"] = True
    
    def _ensure_tts_loaded(self):
        """Ensure TTS models are loaded"""
        if not self.models_loaded["tts"]:
            print("Loading TTS models on-demand...")
            self.tts_tokenizer = VitsTokenizer.from_pretrained("Ken-Z/latin_SpeechT5")
            self.tts_model = VitsModel.from_pretrained(
                "Ken-Z/latin_SpeechT5",
                torch_dtype=torch.float32
            ).to(self.device)
            self.models_loaded["tts"] = True
    
    def _cleanup_models(self):
        """Free up memory by clearing unused models"""
        log_memory("before cleanup")
        if self.asr_model is not None:
            del self.asr_model
            self.asr_model = None
            self.models_loaded["asr"] = False
        if self.asr_processor is not None:
            del self.asr_processor
            self.asr_processor = None
        if self.tts_model is not None:
            del self.tts_model
            self.tts_model = None
            self.models_loaded["tts"] = False
        if self.tts_tokenizer is not None:
            del self.tts_tokenizer
            self.tts_tokenizer = None
        gc.collect()
        log_memory("after cleanup")
        print("Models cleaned up from memory")
    
    def transcribe_audio(self, audio_path):
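        """Transcribe a Latin audio file with the Latin-fine-tuned Whisper model and return the text."""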
        try:
            # Ensure ASR models are loaded
            self._ensure_asr_loaded()
            
            audio, _ = librosa.load(audio_path, sr=16000)
            input_features = self.asr_processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(self.device)
            with torch.no_grad():
                predicted_ids = self.asr_model.generate(input_features)
                result = self.asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
            
            # Clean up tensors but keep models loaded
            del input_features, predicted_ids
            gc.collect()
            
            return result
        except Exception as e:
            print(f"ASR Error: {str(e)}")
            return f"Error: {str(e)}"
    
    def _call_gemini(self, prompt):
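        """Send a prompt to Gemini and return the stripped response text, or an error string on failure."""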
        try:
            return self.gemini_model.generate_content(prompt).text.strip()
        except Exception as e:
            print(f"Gemini API error: {e}")
            return "Error: Gemini API not available"
    
    def generate_response(self, text):
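        """Ask Gemini for a short conversational reply in Classical Latin."""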
        prompt = f"""You are a Latin conversation bot. Respond ONLY in Latin, keep responses to 1-2 sentences, use proper Classical Latin grammar with proper diacritics, and be conversational.

Examples: "Salve" → "Salve! Quid agis hodie?", "Hello" → "Salve! Latine loquere, quaeso!"

User: {text}
Response:"""
        return self._call_gemini(prompt)
    
    def improve_latin_grammar(self, text):
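        """Ask Gemini to correct the Latin text and return a dict with 'corrected' text and an 'explanation'."""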
        prompt = f"""Fix Latin grammar, diacritics, and word order. Format:
CORRECTED: [corrected text]
EXPLANATION: [brief explanation of fixes only]

Text: {text}"""
        
        response = self._call_gemini(prompt)
        
        # Parse response
        corrected = explanation = ""
        for line in response.split('\n'):
            if line.startswith("CORRECTED:"):
                corrected = line[10:].strip()
            elif line.startswith("EXPLANATION:"):
                explanation = line[12:].strip()
        
        return {
            "corrected": corrected or text,
            "explanation": explanation or "No explanation provided."
        }
    
    def translate_latin(self, text, target_language):
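        """Translate the Latin text into the target language via Gemini, returning only the translation."""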
        prompt = f"""Translate this Latin text to {target_language}. Return ONLY the translation, no explanations.

Latin text: {text}
{target_language} translation:"""
        return self._call_gemini(prompt)
    
    def synthesize_speech(self, text):
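        """Synthesize speech for the Latin text with the VITS model and return the path to a temporary WAV file."""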
        try:
            # Ensure TTS models are loaded
            self._ensure_tts_loaded()
            
            inputs = self.tts_tokenizer(text, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                speech = self.tts_model(**inputs).waveform.squeeze().cpu().numpy()
            
            # Clean up tensors but keep models loaded
            del inputs
            gc.collect()
            
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Write at the model's configured sampling rate instead of hardcoding 16 kHz
                sampling_rate = getattr(self.tts_model.config, "sampling_rate", 16000)
                sf.write(tmp_file.name, speech, samplerate=sampling_rate)
                return tmp_file.name
        except Exception as e:
            print(f"TTS error: {e}")
            return None
    

def add_message(history, message):
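    """Append the user's turn to the chat history; uploaded or recorded audio is transcribed first."""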
    for file_info in message["files"]:
        file_path = file_info.path if hasattr(file_info, 'path') else file_info
        if file_path.endswith(('.wav', '.mp3', '.m4a', '.ogg', '.flac')):
            transcription = bot_instance.transcribe_audio(file_path)
            history.append({"role": "user", "content": f"🎀 {transcription}"})
    
    if message["text"] and message["text"].strip():
        history.append({"role": "user", "content": message["text"]})
    
    return history, gr.MultimodalTextbox(value=None, interactive=False)

def get_dropdown_choices(history):
    """Generate all dropdown choices at once"""
    replay_choices = [(f"πŸ”Š {text[:30]}{'...' if len(text) > 30 else ''}", msg_id) 
                     for msg_id, text in bot_instance.message_texts.items()]
    improve_choices = [(f"Message {i+1}: {msg['content'].replace('🎀 ', '')[:50]}{'...' if len(msg['content'].replace('🎀 ', '')) > 50 else ''}", i)
                      for i, msg in enumerate(history) if msg["role"] == "user"]
    translate_choices = [(f"Bot {i+1}: {msg['content'][:50]}{'...' if len(msg['content']) > 50 else ''}", i)
                        for i, msg in enumerate(history) if msg["role"] == "assistant"]
    return replay_choices, improve_choices, translate_choices

def bot(history):
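    """Generate the Latin reply, synthesize its audio, and rebuild the dropdown choices."""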
    if not history:
        return history, None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
    
    last_message = history[-1]["content"]
    user_text = last_message.replace("🎤 ", "") if last_message.startswith("🎤 ") else last_message
    
    response_text = bot_instance.generate_response(user_text)
    message_id = f"msg_{len(history)}_{int(time.time())}"
    
    history.append({"role": "assistant", "content": response_text})
    
    audio_file = bot_instance.synthesize_speech(response_text)
    if audio_file:
        bot_instance.message_audio[message_id] = audio_file
        bot_instance.message_texts[message_id] = response_text
    
    replay_choices, improve_choices, translate_choices = get_dropdown_choices(history)
    return history, audio_file, gr.Dropdown(choices=replay_choices), gr.Dropdown(choices=improve_choices), gr.Dropdown(choices=translate_choices)

def improve_message_grammar(history, message_index):
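    """Replace the selected user message with Gemini's corrected Latin and return the explanation."""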
    if not history or message_index < 0 or message_index >= len(history) or history[message_index]["role"] != "user":
        return history, ""
    
    original_text = history[message_index]["content"]
    prefix = "🎀 " if original_text.startswith("🎀 ") else ""
    text_to_improve = original_text.replace("🎀 ", "")
    
    improvement_result = bot_instance.improve_latin_grammar(text_to_improve)
    corrected_text = improvement_result["corrected"]
    explanation = improvement_result["explanation"]
    
    if corrected_text and corrected_text != text_to_improve:
        history[message_index]["content"] = f"{prefix}{corrected_text} ✨"
    
    return history, explanation

def clear_all_data():
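    """Reset the chat state, drop cached audio, and unload the models to free memory."""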
    bot_instance.message_audio.clear()
    bot_instance.message_texts.clear()
    # Also clean up models to free memory
    bot_instance._cleanup_models()
    print("All data and models cleared from memory")
    return [], None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

# Initialize the bot instance (pre-loads the ASR and TTS models at startup)
print("🚀 Initializing Latin Conversation Bot...")
bot_instance = LatinConversationBot()

with gr.Blocks(title="πŸ›οΈ Latin Conversation Bot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # πŸ›οΈ Latin Conversation Bot
    Speak or type in Latin for AI-powered conversations with speech synthesis and grammar improvement!
    """)

    
    chatbot = gr.Chatbot(type="messages", height=400, show_label=False)
    
    chat_input = gr.MultimodalTextbox(
        interactive=True, file_types=["audio"], placeholder="🎤 Record or type in Latin...",
        show_label=False, sources=["microphone", "upload"]
    )
    
    with gr.Row():
        audio_output = gr.Audio(label="🔊 Bot Response", autoplay=True, scale=2)
        replay_dropdown = gr.Dropdown(label="🔄 Replay Message", choices=[], scale=1)
    
    with gr.Row():
        improve_dropdown = gr.Dropdown(label="✨ Select Message to Improve", choices=[], scale=2)
        improve_btn = gr.Button("✨ Improve Grammar", size="sm", variant="secondary", scale=1)
    
    grammar_explanation = gr.Textbox(label="📚 Grammar Explanation", interactive=False, visible=False)
    
    with gr.Row():
        translate_dropdown = gr.Dropdown(label="🌍 Select Bot Message to Translate", choices=[], scale=2)
        language_dropdown = gr.Dropdown(
            label="Target Language", 
            choices=["English", "Spanish", "French", "German", "Italian", "Portuguese", "Chinese", "Japanese"], 
            value="English", 
            scale=1
        )
        translate_btn = gr.Button("🌍 Translate", size="sm", variant="secondary", scale=1)
    
    translation_output = gr.Textbox(label="📝 Translation", interactive=False, visible=False)
    
    clear_btn = gr.Button("🗑️ Clear", size="sm")

    # Event handlers
    chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
    bot_msg = chat_msg.then(bot, chatbot, [chatbot, audio_output, replay_dropdown, improve_dropdown, translate_dropdown])
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
    
    replay_dropdown.change(
        lambda msg_id: bot_instance.message_audio.get(msg_id) if msg_id else None,
        inputs=[replay_dropdown], outputs=[audio_output]
    )
    
    clear_btn.click(clear_all_data, outputs=[chatbot, audio_output, replay_dropdown, improve_dropdown, translate_dropdown])
    
    def improve_selected_message(history, selected_index):
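        """Apply grammar correction to the selected user message and show the explanation box when useful."""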
        if selected_index is None:
            _, improve_choices, _ = get_dropdown_choices(history)
            return history, gr.Dropdown(choices=improve_choices), gr.Textbox(visible=False)
        
        improved_history, explanation = improve_message_grammar(history, selected_index)
        _, improve_choices, _ = get_dropdown_choices(improved_history)
        
        show_explanation = bool(explanation) and explanation not in ("No explanation provided.", "No corrections needed.")
        return improved_history, gr.Dropdown(choices=improve_choices), gr.Textbox(value=explanation if show_explanation else "", visible=show_explanation)
    
    def translate_selected_message(history, selected_index, target_language):
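        """Translate the selected bot message and reveal the translation box."""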
        if selected_index is None or not history or selected_index >= len(history) or history[selected_index]["role"] != "assistant":
            return gr.Textbox(visible=False)
        
        latin_text = history[selected_index]["content"]
        translation = bot_instance.translate_latin(latin_text, target_language)
        return gr.Textbox(value=f"Original: {latin_text}\n\n{target_language}: {translation}", visible=True)
    
    improve_btn.click(improve_selected_message, [chatbot, improve_dropdown], [chatbot, improve_dropdown, grammar_explanation])
    translate_btn.click(translate_selected_message, [chatbot, translate_dropdown, language_dropdown], [translation_output])

if __name__ == "__main__":
    # Launch with optimized settings for HF Spaces
    demo.launch(
        server_port=7860,  # Standard HF Spaces port
        share=False,
        show_error=True,
        quiet=False  # Show startup logs
    )