|
|
|
|
|
"""voc6.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/17WecCovbP3TgYvHDyZ4Yckj77r2q5Nam |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import re |
|
|
from typing import List, Dict, Tuple |
|
|
import numpy as np |
|
|
import faiss |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from dataclasses import dataclass |
|
|
import pickle |
|
|
import os |
|
|
import io |
|
|
from typing import Optional |
|
|
from spitch import Spitch |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import tempfile |
|
|
import os |
|
|
import atexit |
|
|
import glob |
|
|
import io |
|
|
from typing import Optional |
|
|
from spitch import Spitch |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpitchVoiceHandler:
    """
    Handles all voice-related operations using Spitch API.
    Supports multilingual speech-to-text and text-to-speech.
    """

    def __init__(self, api_key: str):
        """
        Initialize Spitch client.

        Args:
            api_key: Your Spitch API key
        """
        self.client = Spitch(api_key=api_key)

    def transcribe_audio(
        self,
        audio_file,
        source_language: str = "en",
        model: str = "mansa_v1"
    ) -> str:
        """
        Transcribe audio to text using Spitch.
        Supports multiple African and international languages.

        Args:
            audio_file: Audio file path or file-like object
            source_language: Language code (e.g., 'en', 'yo', 'ig', 'ha')
            model: Spitch model to use (default: mansa_v1)

        Returns:
            Transcribed text, or a user-facing apology string on failure
            (errors are reported in-band rather than raised).
        """
        try:
            print(f"🎤 Transcribing audio file: {audio_file}")

            # A string is treated as a filesystem path to open ourselves;
            # anything else is assumed to be a file-like object the SDK
            # can consume directly.
            if isinstance(audio_file, str):
                with open(audio_file, 'rb') as f:
                    response = self.client.speech.transcribe(
                        content=f,
                        language=source_language,
                        model=model
                    )
            else:
                response = self.client.speech.transcribe(
                    content=audio_file,
                    language=source_language,
                    model=model
                )

            print(f"Response type: {type(response)}")

            # Probe the response shape defensively: `.text` may be a method
            # or a plain attribute depending on SDK version — TODO confirm
            # against the installed spitch release.
            if hasattr(response, 'text') and callable(response.text):
                transcription_text = response.text()
            elif hasattr(response, 'text'):
                transcription_text = response.text
            elif hasattr(response, 'json'):
                json_data = response.json()
                transcription_text = json_data.get('text', str(json_data))
            else:
                # Last resort: stringify whatever came back.
                transcription_text = str(response)

            print(f"✅ Transcription: {transcription_text}")
            return transcription_text

        except Exception as e:
            # Never crash the caller on bad audio; return the error text
            # so the UI can display it.
            print(f"❌ Transcription error: {e}")
            import traceback
            traceback.print_exc()
            return f"Sorry, I couldn't understand the audio. Error: {str(e)}"

    def translate_to_english(self, text: str, source_lang: str = "auto") -> str:
        """
        Translate text to English using Spitch translation API.

        Args:
            text: Text to translate
            source_lang: Source language code or 'auto' for auto-detection

        Returns:
            Translated text in English; falls back to the original text
            when the source is already English or translation fails.
        """
        try:
            # Already English: nothing to do.
            if source_lang == "en":
                return text

            translation = self.client.text.translate(
                text=text,
                source=source_lang,
                target="en"
            )
            return translation.text

        except Exception as e:
            # Best-effort: return the untranslated input on failure.
            print(f"Translation error: {e}")
            return text

    def synthesize_speech(
        self,
        text: str,
        target_language: str = "en",
        voice: str = "lina"
    ) -> Optional[bytes]:
        """
        Convert text to speech using Spitch TTS.

        Args:
            text: Text to convert to speech
            target_language: Target language for speech
            voice: Voice to use (e.g., 'lina', 'ada', 'kofi')

        Returns:
            Audio bytes, or None when synthesis fails or the response
            is not stream-like.
        """
        try:
            response = self.client.speech.generate(
                text=text,
                language=target_language,
                voice=voice
            )

            # Expected happy path: a stream-like response we can read().
            if hasattr(response, 'read'):
                audio_bytes = response.read()
                print(f"✅ TTS generated {len(audio_bytes)} bytes of audio")
                return audio_bytes
            else:
                # Unexpected response shape; dump diagnostics and bail.
                print(f"❌ Response type: {type(response)}")
                print(f"❌ Response attributes: {dir(response)}")
                return None

        except Exception as e:
            print(f"❌ TTS error: {e}")
            import traceback
            traceback.print_exc()
            return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WemaVoiceAssistant:
    """
    Complete voice-enabled assistant combining Spitch voice I/O
    with your existing Wema RAG system.
    """

    def __init__(
        self,
        rag_system,
        chain,
        spitch_api_key: str
    ):
        """
        Initialize the voice assistant.

        Args:
            rag_system: Your initialized WemaRAGSystem
            chain: Your LangChain RAG chain (already created)
            spitch_api_key: Spitch API key
        """
        self.rag_system = rag_system
        self.voice_handler = SpitchVoiceHandler(spitch_api_key)
        self.chain = chain

    def process_voice_query(
        self,
        audio_input,
        input_language: str = "en",
        output_language: str = "en",
        voice: str = "lina"
    ):
        """
        Complete voice interaction pipeline:
        1. Speech to text (any language)
        2. Translate to English if needed
        3. Query RAG system
        4. Generate response
        5. Translate response if needed
        6. Text to speech

        Args:
            audio_input: Audio file from user
            input_language: User's spoken language
            output_language: Desired response language
            voice: TTS voice to use

        Returns:
            tuple: (response_text, response_audio); response_audio is None
            when TTS fails, and on any error the first element is an
            error message string.
        """
        try:
            # 1) Speech-to-text in the caller's spoken language.
            print(f"Transcribing audio in {input_language}...")
            transcribed_text = self.voice_handler.transcribe_audio(
                audio_input,
                source_language=input_language
            )
            print(f"Transcribed: {transcribed_text}")

            # 2) Normalize the query to English for the RAG chain.
            if input_language != "en":
                print("Translating to English...")
                english_query = self.voice_handler.translate_to_english(
                    transcribed_text,
                    source_lang=input_language
                )
            else:
                english_query = transcribed_text

            print(f"English query: {english_query}")

            # 3) Retrieve context and generate an answer.
            print("Querying RAG system...")
            response_text = self.chain.invoke({"query": english_query})
            # NOTE(review): slicing assumes the chain returns a str
            # (e.g. via StrOutputParser) — confirm against chain setup.
            print(f"RAG response: {response_text[:100]}...")

            # 4) Translate the answer back into the requested language.
            if output_language != "en":
                print(f"Translating response to {output_language}...")
                translation = self.voice_handler.client.text.translate(
                    text=response_text,
                    source="en",
                    target=output_language
                )
                final_text = translation.text
            else:
                final_text = response_text

            # 5) Text-to-speech in the requested language/voice.
            print("Generating speech...")
            audio_response = self.voice_handler.synthesize_speech(
                final_text,
                target_language=output_language,
                voice=voice
            )

            return final_text, audio_response

        except Exception as e:
            # Report failures in-band so the UI never crashes.
            error_msg = f"An error occurred: {str(e)}"
            print(error_msg)
            return error_msg, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_audio_to_temp_file(audio_bytes):
    """Save audio bytes to a temporary .mp3 file and return its path.

    Args:
        audio_bytes: Raw audio data, or None.

    Returns:
        Path to the written temp file, or None when audio_bytes is None.
    """
    if audio_bytes is None:
        return None

    # delete=False: the file must outlive this function so Gradio can serve
    # it; cleanup_temp_audio_files() removes leftovers at interpreter exit.
    # The context manager guarantees the handle is closed even if write fails.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
        temp_file.write(audio_bytes)
        return temp_file.name
|
|
|
|
|
|
|
|
def cleanup_temp_audio_files():
    """Best-effort removal of stray tmp*.mp3 files from the temp directory.

    Registered via atexit to clean up files created by
    save_audio_to_temp_file(). Removal failures are ignored per-file.
    """
    temp_dir = tempfile.gettempdir()
    for temp_file in glob.glob(os.path.join(temp_dir, "tmp*.mp3")):
        try:
            os.remove(temp_file)
        except OSError:
            # File already gone or locked by another process; skip it.
            # (Narrowed from a bare except, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass
|
|
|
|
|
|
|
|
|
|
|
# Remove leftover TTS temp files when the interpreter exits.
atexit.register(cleanup_temp_audio_files)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_voice_gradio_interface(
    rag_system,
    chain,
    spitch_api_key: str
):
    """
    Create a Gradio interface with BOTH text and voice input/output capabilities.

    Args:
        rag_system: Your initialized WemaRAGSystem
        chain: Your LangChain RAG chain (already created)
        spitch_api_key: Spitch API key

    Returns:
        Gradio Interface
    """
    assistant = WemaVoiceAssistant(rag_system, chain, spitch_api_key)

    # Spitch language code and the TTS voices available per language.
    LANGUAGE_CONFIG = {
        "English": {
            "code": "en",
            "voices": ["john", "lucy", "lina", "jude", "henry", "kani", "kingsley",
                       "favour", "comfort", "daniel", "remi"]
        },
        "Yoruba": {
            "code": "yo",
            "voices": ["sade", "funmi", "segun", "femi"]
        },
        "Igbo": {
            "code": "ig",
            "voices": ["obinna", "ngozi", "amara", "ebuka"]
        },
        "Hausa": {
            "code": "ha",
            "voices": ["hasan", "amina", "zainab", "aliyu"]
        }
    }

    ALL_LANGUAGES = list(LANGUAGE_CONFIG.keys())

    # (Removed an unused VOICES list that duplicated a stale subset of
    # LANGUAGE_CONFIG's voices.)

    def handle_text_query(text_input):
        """Handle text-only queries; second return slot pairs with a hidden Audio."""
        if not text_input or text_input.strip() == "":
            return "Please enter a question.", None

        try:
            response = chain.invoke({"query": text_input})
            return response, None
        except Exception as e:
            return f"Error: {str(e)}", None

    def update_voices(language):
        """Update voice dropdown based on selected language."""
        voices = LANGUAGE_CONFIG.get(language, {}).get("voices", ["lina"])
        return gr.Dropdown(choices=voices, value=voices[0])

    def handle_voice_interaction(audio, input_lang, output_lang, voice):
        """Full voice pipeline: STT -> translate -> RAG -> translate -> TTS."""
        print("="*60)
        print("VOICE INTERACTION STARTED")
        print(f"Audio input: {audio}")
        print(f"Input language: {input_lang}")
        print(f"Output language: {output_lang}")
        print(f"Voice: {voice}")
        print("="*60)

        if audio is None:
            return "Please record or upload audio.", None

        # Resolve display names to Spitch codes; default to English.
        input_config = LANGUAGE_CONFIG.get(input_lang, LANGUAGE_CONFIG["English"])
        output_config = LANGUAGE_CONFIG.get(output_lang, LANGUAGE_CONFIG["English"])

        input_code = input_config["code"]
        output_code = output_config["code"]

        # The voice dropdown may lag behind a language change; force a
        # voice that actually exists for the output language.
        available_voices = output_config["voices"]
        if voice not in available_voices:
            voice = available_voices[0]
            print(f"⚠️ Voice changed to {voice} for {output_lang}")

        try:
            print("\n🎤 Processing voice query...")

            # 1) Speech-to-text.
            transcribed_text = assistant.voice_handler.transcribe_audio(
                audio,
                source_language=input_code
            )
            print(f"📝 Transcribed: {transcribed_text}")

            # 2) Normalize the query to English for the RAG chain.
            if input_code != "en":
                print("🌍 Translating to English...")
                english_query = assistant.voice_handler.translate_to_english(
                    transcribed_text,
                    source_lang=input_code
                )
            else:
                english_query = transcribed_text

            print(f"🇬🇧 English query: {english_query}")

            # 3) Retrieve context and generate an answer.
            print("🔍 Querying RAG system...")
            response_text = assistant.chain.invoke({"query": english_query})
            print(f"✅ RAG response: {response_text[:100]}...")

            # 4) Translate the answer back; fall back to English on failure.
            if output_code != "en":
                print(f"🌍 Translating response to {output_lang}...")
                try:
                    translation = assistant.voice_handler.client.text.translate(
                        text=response_text,
                        source="en",
                        target=output_code
                    )
                    final_text = translation.text
                except Exception as e:
                    print(f"⚠️ Translation failed: {e}, using English")
                    final_text = response_text
            else:
                final_text = response_text

            # 5) Text-to-speech.
            print(f"🔊 Generating speech in {output_lang} with voice {voice}...")
            audio_bytes = assistant.voice_handler.synthesize_speech(
                final_text,
                target_language=output_code,
                voice=voice
            )

            print(f"🔊 Audio bytes type: {type(audio_bytes)}")
            print(f"🔊 Audio bytes length: {len(audio_bytes) if audio_bytes else 0}")

            # Gradio's Audio(type="filepath") component needs a file on disk.
            audio_file_path = None
            if audio_bytes:
                print("\n💾 Saving audio to temp file...")
                audio_file_path = save_audio_to_temp_file(audio_bytes)
                print(f"✅ Audio saved to: {audio_file_path}")

                if audio_file_path and os.path.exists(audio_file_path):
                    file_size = os.path.getsize(audio_file_path)
                    print(f"✅ File size: {file_size} bytes")
                else:
                    print("❌ File was not created properly!")
            else:
                print("❌ No audio bytes received from TTS")

            print("="*60)
            return final_text, audio_file_path

        except Exception as e:
            error_msg = f"Error processing voice: {str(e)}"
            print(f"\n❌ ERROR: {error_msg}")
            import traceback
            traceback.print_exc()
            print("="*60)
            return error_msg, None

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🏦 Wema Bank AI Assistant
        ### Powered by Spitch AI & LangChain RAG

        Choose how you want to interact: Type or Speak!
        """)

        with gr.Tabs():

            with gr.Tab("💬 Text Chat"):
                gr.Markdown("### Type your banking questions")

                text_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask me anything about Wema Bank products and services...",
                    lines=3
                )

                text_submit_btn = gr.Button("📤 Send", variant="primary", size="lg")

                text_output = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False
                )

                gr.Examples(
                    examples=[
                        ["What is ALAT?"],
                        ["How do I open a savings account?"],
                        ["Tell me about Wema Kiddies Account"],
                        ["How can I avoid phishing scams?"],
                        ["What loans does Wema Bank offer?"]
                    ],
                    inputs=text_input,
                    label="💡 Try these questions"
                )

                text_submit_btn.click(
                    fn=handle_text_query,
                    inputs=text_input,
                    outputs=[text_output, gr.Audio(visible=False)]
                )

                # Also submit on Enter from the textbox.
                text_input.submit(
                    fn=handle_text_query,
                    inputs=text_input,
                    outputs=[text_output, gr.Audio(visible=False)]
                )

            with gr.Tab("🎤 Voice Chat"):
                gr.Markdown("""
                ### Speak your banking questions in your language!

                **✅ Fully Supported Nigerian Languages:**
                - 🇬🇧 **English** - 11 voices available
                - 🇳🇬 **Yoruba** - 4 voices (Sade, Funmi, Segun, Femi)
                - 🇳🇬 **Igbo** - 4 voices (Obinna, Ngozi, Amara, Ebuka)
                - 🇳🇬 **Hausa** - 4 voices (Hasan, Amina, Zainab, Aliyu)

                Speak naturally and get responses in both text and audio in your preferred language!
                """)

                with gr.Row():
                    with gr.Column():
                        audio_input = gr.Audio(
                            sources=["microphone", "upload"],
                            type="filepath",
                            label="🎙️ Record or Upload Audio"
                        )

                        input_language = gr.Dropdown(
                            choices=ALL_LANGUAGES,
                            value="English",
                            label="Your Language (Speech Input)"
                        )

                    with gr.Column():
                        output_language = gr.Dropdown(
                            choices=ALL_LANGUAGES,
                            value="English",
                            label="Response Language (Audio Output)"
                        )

                        voice_selection = gr.Dropdown(
                            choices=LANGUAGE_CONFIG["English"]["voices"],
                            value="lina",
                            label="Voice"
                        )

                # Keep the voice list in sync with the chosen output language.
                output_language.change(
                    fn=update_voices,
                    inputs=output_language,
                    outputs=voice_selection
                )

                voice_submit_btn = gr.Button("🚀 Ask Wema Assist", variant="primary", size="lg")

                voice_text_output = gr.Textbox(
                    label="📝 Text Response",
                    lines=8,
                    interactive=False
                )

                voice_audio_output = gr.Audio(
                    label="🔊 Audio Response",
                    type="filepath"
                )

                voice_submit_btn.click(
                    fn=handle_voice_interaction,
                    inputs=[audio_input, input_language, output_language, voice_selection],
                    outputs=[voice_text_output, voice_audio_output]
                )

        gr.Markdown("""
        ---
        ### 📌 Features
        - **Text Chat**: Fast and simple - just type and get instant responses
        - **Voice Chat**: Full support for Nigerian languages!

        ### 🇳🇬 Supported Nigerian Languages
        ✅ **English** - 11 different voices (male & female)
        ✅ **Yoruba** - E ku ọjọ! (4 authentic Yoruba voices)
        ✅ **Igbo** - Nnọọ! (4 authentic Igbo voices)
        ✅ **Hausa** - Sannu! (4 authentic Hausa voices)

        💡 **All features work in every language:**
        - 🎤 Speak your question in your language
        - 📝 Get text response translated
        - 🔊 Hear authentic audio response in your language
        - 🔄 Seamless translation between languages
        """)

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_hybrid_interface(
    rag_system,
    chain,
    spitch_api_key: str
):
    """
    Creates a simpler interface supporting both text and voice input.

    Args:
        rag_system: Your initialized WemaRAGSystem
        chain: Your LangChain RAG chain (already created)
        spitch_api_key: Spitch API key

    Returns:
        Gradio Interface
    """
    assistant = WemaVoiceAssistant(rag_system, chain, spitch_api_key)

    # UI language name -> language code used by the voice pipeline.
    lang_codes = {
        "English": "en",
        "Yoruba": "yo",
        "Igbo": "ig",
        "Hausa": "ha"
    }

    def handle_text_query(text_input):
        """Run a plain text question through the RAG chain."""
        try:
            answer = chain.invoke({"query": text_input})
        except Exception as e:
            return f"Error: {str(e)}", None
        return answer, None

    def handle_voice_query(audio, input_lang, output_lang, voice):
        """Run an audio question through the full voice pipeline."""
        if audio is None:
            return "Please provide audio input.", None

        reply_text, reply_audio = assistant.process_voice_query(
            audio,
            input_language=lang_codes.get(input_lang, "en"),
            output_language=lang_codes.get(output_lang, "en"),
            voice=voice
        )

        # Persist TTS bytes so the filepath Audio component can play them.
        saved_path = save_audio_to_temp_file(reply_audio) if reply_audio else None
        return reply_text, saved_path

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏦 Wema Bank AI Assistant")

        with gr.Tabs():

            with gr.Tab("💬 Text Chat"):
                text_input = gr.Textbox(
                    label="Type your question",
                    placeholder="Ask about Wema Bank products and services..."
                )
                text_submit = gr.Button("Send")
                text_output = gr.Textbox(label="Response", lines=10)

                text_submit.click(
                    fn=handle_text_query,
                    inputs=text_input,
                    outputs=[text_output, gr.Audio(visible=False)]
                )

            with gr.Tab("🎤 Voice Chat"):
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")

                with gr.Row():
                    input_lang = gr.Dropdown(
                        ["English", "Yoruba", "Igbo", "Hausa"],
                        value="English",
                        label="Input Language"
                    )
                    output_lang = gr.Dropdown(
                        ["English", "Yoruba", "Igbo", "Hausa"],
                        value="English",
                        label="Output Language"
                    )
                    voice = gr.Dropdown(
                        ["lina", "ada", "kofi"],
                        value="lina",
                        label="Voice"
                    )

                voice_submit = gr.Button("Ask")
                voice_text_output = gr.Textbox(label="Response Text", lines=8)
                voice_audio_output = gr.Audio(label="Audio Response", type="filepath")

                voice_submit.click(
                    fn=handle_voice_query,
                    inputs=[audio_input, input_lang, output_lang, voice],
                    outputs=[voice_text_output, voice_audio_output]
                )

    return demo
|
|
|
|
|
@dataclass
class DocumentChunk:
    """Represents a chunk of text with metadata."""
    # The chunk's raw text content.
    text: str
    # Source-document metadata (url, title, meta_description, plus the
    # 'section' key added by the chunker).
    metadata: Dict
    # Sequential id of the chunk within its source document.
    chunk_id: int
|
|
|
|
|
class WemaDocumentChunker: |
|
|
"""Handles intelligent chunking of Wema Bank documents.""" |
|
|
|
|
|
def __init__(self, chunk_size: int = 800, overlap: int = 150): |
|
|
""" |
|
|
Initialize the chunker. |
|
|
|
|
|
Args: |
|
|
chunk_size: Target size for each chunk in characters |
|
|
overlap: Number of characters to overlap between chunks |
|
|
""" |
|
|
self.chunk_size = chunk_size |
|
|
self.overlap = overlap |
|
|
|
|
|
def identify_sections(self, text: str) -> List[Tuple[str, str]]: |
|
|
""" |
|
|
Identify logical sections in the document. |
|
|
|
|
|
Returns: |
|
|
List of tuples (section_title, section_content) |
|
|
""" |
|
|
sections = [] |
|
|
|
|
|
|
|
|
section_patterns = [ |
|
|
r'(Avoiding Financial and Phishing Scams)', |
|
|
r'(Keeping Your Card.*?Safe)', |
|
|
r'(E-mails and calls from.*?)', |
|
|
r'(Scam Alert Tips)', |
|
|
r'(Guard Yourself)', |
|
|
r'(Bank Verification Number)', |
|
|
r'(Personal Banking)', |
|
|
r'(Business Banking)', |
|
|
r'(Corporate Banking)', |
|
|
r'(.*?Account)', |
|
|
r'(.*?Loan.*?)', |
|
|
] |
|
|
|
|
|
|
|
|
combined_pattern = '|'.join(section_patterns) |
|
|
matches = list(re.finditer(combined_pattern, text, re.IGNORECASE)) |
|
|
|
|
|
if matches: |
|
|
for i, match in enumerate(matches): |
|
|
start = match.start() |
|
|
end = matches[i + 1].start() if i + 1 < len(matches) else len(text) |
|
|
section_title = match.group(0).strip() |
|
|
section_content = text[start:end].strip() |
|
|
sections.append((section_title, section_content)) |
|
|
else: |
|
|
|
|
|
sections.append(("General Content", text)) |
|
|
|
|
|
return sections |
|
|
|
|
|
def chunk_text(self, text: str, metadata: Dict) -> List[DocumentChunk]: |
|
|
""" |
|
|
Chunk text with semantic awareness and overlap. |
|
|
|
|
|
Args: |
|
|
text: Text to chunk |
|
|
metadata: Metadata to attach to chunks |
|
|
|
|
|
Returns: |
|
|
List of DocumentChunk objects |
|
|
""" |
|
|
chunks = [] |
|
|
|
|
|
|
|
|
sections = self.identify_sections(text) |
|
|
|
|
|
chunk_id = 0 |
|
|
for section_title, section_content in sections: |
|
|
|
|
|
if len(section_content) <= self.chunk_size: |
|
|
chunk_metadata = metadata.copy() |
|
|
chunk_metadata['section'] = section_title |
|
|
chunks.append(DocumentChunk( |
|
|
text=section_content, |
|
|
metadata=chunk_metadata, |
|
|
chunk_id=chunk_id |
|
|
)) |
|
|
chunk_id += 1 |
|
|
else: |
|
|
|
|
|
sentences = self._split_into_sentences(section_content) |
|
|
current_chunk = [] |
|
|
current_length = 0 |
|
|
|
|
|
for sentence in sentences: |
|
|
sentence_length = len(sentence) |
|
|
|
|
|
if current_length + sentence_length > self.chunk_size and current_chunk: |
|
|
|
|
|
chunk_text = ' '.join(current_chunk) |
|
|
chunk_metadata = metadata.copy() |
|
|
chunk_metadata['section'] = section_title |
|
|
chunks.append(DocumentChunk( |
|
|
text=chunk_text, |
|
|
metadata=chunk_metadata, |
|
|
chunk_id=chunk_id |
|
|
)) |
|
|
chunk_id += 1 |
|
|
|
|
|
|
|
|
overlap_text = chunk_text[-self.overlap:] if len(chunk_text) > self.overlap else chunk_text |
|
|
overlap_sentences = self._split_into_sentences(overlap_text) |
|
|
current_chunk = overlap_sentences |
|
|
current_length = sum(len(s) for s in current_chunk) |
|
|
|
|
|
current_chunk.append(sentence) |
|
|
current_length += sentence_length |
|
|
|
|
|
|
|
|
if current_chunk: |
|
|
chunk_metadata = metadata.copy() |
|
|
chunk_metadata['section'] = section_title |
|
|
chunks.append(DocumentChunk( |
|
|
text=' '.join(current_chunk), |
|
|
metadata=chunk_metadata, |
|
|
chunk_id=chunk_id |
|
|
)) |
|
|
chunk_id += 1 |
|
|
|
|
|
return chunks |
|
|
|
|
|
def _split_into_sentences(self, text: str) -> List[str]: |
|
|
"""Split text into sentences.""" |
|
|
|
|
|
sentences = re.split(r'(?<=[.!?])\s+', text) |
|
|
return [s.strip() for s in sentences if s.strip()] |
|
|
|
|
|
|
|
|
class WemaRAGSystem:
    """Complete RAG system for Wema Bank documents."""

    def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
        """
        Initialize the RAG system.

        Args:
            model_name: Name of the sentence transformer model to use
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        # The embedding width determines the FAISS index dimensionality.
        self.dimension = self.model.get_sentence_embedding_dimension()
        self.index = None      # FAISS index; built by _create_embeddings()
        self.chunks = []       # DocumentChunk list, row-aligned with the index
        self.chunker = WemaDocumentChunker()

    def load_and_process_document(self, json_path: str):
        """
        Load JSON document, chunk it, and create embeddings.

        Args:
            json_path: Path to the JSON file

        Raises:
            ValueError: If the JSON top level is neither a dict nor a list.
        """
        print(f"Loading document from: {json_path}")

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Accept either a single document object or a list of them.
        all_chunks = []
        if isinstance(data, list):
            documents = data
        elif isinstance(data, dict):
            documents = [data]
        else:
            raise ValueError("JSON must contain a document object or list of documents")

        for doc in documents:
            text = doc.get('text', '')
            metadata = {
                'url': doc.get('url', ''),
                'title': doc.get('title', ''),
                'meta_description': doc.get('meta_description', '')
            }

            chunks = self.chunker.chunk_text(text, metadata)
            all_chunks.extend(chunks)
            print(f"Created {len(chunks)} chunks from document: {metadata['title'][:50]}...")

        self.chunks = all_chunks
        print(f"Total chunks created: {len(self.chunks)}")

        self._create_embeddings()

    def _create_embeddings(self):
        """Generate embeddings for all chunks and create FAISS index."""
        print("Generating embeddings...")

        texts = [chunk.text for chunk in self.chunks]
        embeddings = self.model.encode(texts, show_progress_bar=True)

        print("Creating FAISS index...")
        # IndexFlatL2 = exact L2 search; fine at this corpus size.
        self.index = faiss.IndexFlatL2(self.dimension)
        self.index.add(embeddings.astype('float32'))

        print(f"FAISS index created with {self.index.ntotal} vectors")

    def save(self, index_path: str = 'wema_faiss.index',
             chunks_path: str = 'wema_chunks.pkl'):
        """
        Save FAISS index and chunks to disk.

        Args:
            index_path: Path to save FAISS index
            chunks_path: Path to save chunks metadata

        Raises:
            ValueError: If no index has been built yet.
        """
        if self.index is None:
            raise ValueError("No index to save. Process documents first.")

        print(f"Saving FAISS index to: {index_path}")
        faiss.write_index(self.index, index_path)

        print(f"Saving chunks metadata to: {chunks_path}")
        with open(chunks_path, 'wb') as f:
            pickle.dump(self.chunks, f)

        print("Save complete!")

    def load(self, index_path: str = 'wema_faiss.index',
             chunks_path: str = 'wema_chunks.pkl'):
        """
        Load FAISS index and chunks from disk.

        Args:
            index_path: Path to FAISS index
            chunks_path: Path to chunks metadata
        """
        print(f"Loading FAISS index from: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"Loading chunks metadata from: {chunks_path}")
        # SECURITY: pickle.load executes arbitrary code from the file —
        # only load chunk files produced by this system's save().
        with open(chunks_path, 'rb') as f:
            self.chunks = pickle.load(f)

        print(f"Loaded {len(self.chunks)} chunks with index size {self.index.ntotal}")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Search for relevant chunks given a query.

        Args:
            query: Search query
            top_k: Number of results to return

        Returns:
            List of dictionaries containing chunk text, metadata, and
            similarity score (L2 distance; lower is closer). May contain
            fewer than top_k entries when the index is small.

        Raises:
            ValueError: If no index has been loaded or created.
        """
        if self.index is None:
            raise ValueError("No index loaded. Load or create an index first.")

        query_embedding = self.model.encode([query])[0].astype('float32').reshape(1, -1)

        distances, indices = self.index.search(query_embedding, top_k)

        results = []
        for i, idx in enumerate(indices[0]):
            # BUGFIX: FAISS pads the result with -1 when fewer than top_k
            # vectors exist; previously chunks[-1] silently returned the
            # last chunk for those slots.
            if idx < 0:
                continue
            chunk = self.chunks[idx]
            results.append({
                'text': chunk.text,
                'metadata': chunk.metadata,
                'score': float(distances[0][i]),
                'chunk_id': chunk.chunk_id
            })

        return results

    def get_context_for_rag(self, query: str, top_k: int = 3,
                            max_context_length: int = 2000) -> str:
        """
        Get formatted context for RAG applications.

        Args:
            query: Search query
            top_k: Number of chunks to retrieve
            max_context_length: Maximum length of context to return

        Returns:
            Formatted context string
        """
        results = self.search(query, top_k)

        context_parts = []
        current_length = 0

        for i, result in enumerate(results, 1):
            chunk_text = result['text']
            section = result['metadata'].get('section', 'N/A')

            formatted = f"[Source {i} - {section}]\n{chunk_text}\n"

            # Stop once adding another source would exceed the budget.
            if current_length + len(formatted) > max_context_length:
                break

            context_parts.append(formatted)
            current_length += len(formatted)

        return "\n".join(context_parts)
|
|
|
|
|
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableLambda |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
import gradio as gr |
|
|
from typing import Dict, Any, List |
|
|
import json |
|
|
|
|
|
class WemaDocumentProcessorRunnable:
    """
    Wraps the document loading, chunking, embedding, and storing as a LangChain Runnable.
    This preserves ALL the original WemaRAGSystem functionality.
    """

    def __init__(self, rag_system):
        """
        Initialize with a WemaRAGSystem instance.

        Args:
            rag_system: An initialized WemaRAGSystem object
        """
        self.rag = rag_system

        # Each pipeline stage is also exposed individually as a Runnable.
        self.document_loader = RunnableLambda(self._load_document)
        self.chunker = RunnableLambda(self._chunk_documents)
        self.embedder = RunnableLambda(self._create_embeddings)
        self.storer = RunnableLambda(self._store_index)

        # load -> chunk -> embed -> store, composed with LCEL piping.
        self.full_pipeline = (
            self.document_loader
            | self.chunker
            | self.embedder
            | self.storer
        )

    def _load_document(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Load JSON document(s) from disk.

        Args:
            inputs: Dict carrying a 'json_path' key, or a bare path string.

        Returns:
            Dict with the path, the loaded documents, and a status flag.
        """
        if isinstance(inputs, dict):
            json_path = inputs.get("json_path", inputs)
        else:
            json_path = inputs

        print(f"Loading document from: {json_path}")

        with open(json_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)

        # Accept either a single document object or a list of them.
        if isinstance(payload, list):
            documents = payload
        elif isinstance(payload, dict):
            documents = [payload]
        else:
            raise ValueError("JSON must contain a document object or list of documents")

        return {
            "json_path": json_path,
            "documents": documents,
            "status": "loaded"
        }

    def _chunk_documents(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Chunk every loaded document with the RAG system's chunker.

        Args:
            inputs: Dict with a 'documents' key.

        Returns:
            Dict extended with the chunk list and chunk count.
        """
        documents = inputs["documents"]

        print("Chunking documents...")
        gathered = []

        for record in documents:
            body = record.get('text', '')
            meta = {
                'url': record.get('url', ''),
                'title': record.get('title', ''),
                'meta_description': record.get('meta_description', '')
            }

            pieces = self.rag.chunker.chunk_text(body, meta)
            gathered.extend(pieces)
            print(f"Created {len(pieces)} chunks from document: {meta['title'][:50]}...")

        self.rag.chunks = gathered
        print(f"Total chunks created: {len(self.rag.chunks)}")

        return {
            "json_path": inputs.get("json_path"),
            "documents": documents,
            "chunks": gathered,
            "chunk_count": len(gathered),
            "status": "chunked"
        }

    def _create_embeddings(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create embeddings and the FAISS index via the original method.

        Args:
            inputs: Dict with a 'chunks' key.

        Returns:
            Dict extended with the resulting index size.
        """
        print("Generating embeddings...")

        # Delegate to the RAG system's own embedding + indexing routine.
        self.rag._create_embeddings()

        return {
            "json_path": inputs.get("json_path"),
            "documents": inputs["documents"],
            "chunks": inputs["chunks"],
            "chunk_count": inputs["chunk_count"],
            "index_size": self.rag.index.ntotal,
            "status": "embedded"
        }

    def _store_index(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Persist the FAISS index and chunk metadata to disk.

        Args:
            inputs: Dict with processing results (optional 'index_path'
                and 'chunks_path' overrides).

        Returns:
            Dict describing what was saved and where.
        """
        index_path = inputs.get("index_path", "wema_faiss.index")
        chunks_path = inputs.get("chunks_path", "wema_chunks.pkl")

        self.rag.save(index_path=index_path, chunks_path=chunks_path)

        return {
            "json_path": inputs.get("json_path"),
            "chunk_count": inputs["chunk_count"],
            "index_size": inputs["index_size"],
            "index_path": index_path,
            "chunks_path": chunks_path,
            "status": "saved"
        }

    def get_full_pipeline(self):
        """Return the complete load->chunk->embed->store Runnable."""
        return self.full_pipeline

    def get_loader_runnable(self):
        """Return just the document loader stage."""
        return self.document_loader

    def get_chunker_runnable(self):
        """Return just the chunker stage."""
        return self.chunker

    def get_embedder_runnable(self):
        """Return just the embedder stage."""
        return self.embedder

    def get_storer_runnable(self):
        """Return just the storer stage."""
        return self.storer
|
|
|
|
|
|
|
|
|
|
|
class WemaRAGRetrieverRunnable: |
|
|
""" |
|
|
Wraps the retrieval functionality as a LangChain Runnable. |
|
|
""" |
|
|
|
|
|
def __init__(self, rag_system): |
|
|
""" |
|
|
Initialize with an existing WemaRAGSystem instance. |
|
|
|
|
|
Args: |
|
|
rag_system: An initialized WemaRAGSystem object |
|
|
""" |
|
|
self.rag = rag_system |
|
|
self.retriever = RunnableLambda(self._retrieve_context) |
|
|
|
|
|
def _retrieve_context(self, inputs: Dict[str, Any]) -> Dict[str, Any]: |
|
|
""" |
|
|
Retrieves context from the RAG system using the original search method. |
|
|
|
|
|
Args: |
|
|
inputs: Dictionary containing 'query' key |
|
|
|
|
|
Returns: |
|
|
Dictionary with query and context |
|
|
""" |
|
|
query = inputs.get("query", inputs) if isinstance(inputs, dict) else inputs |
|
|
|
|
|
|
|
|
context = self.rag.get_context_for_rag(query, top_k=3) |
|
|
|
|
|
return { |
|
|
"query": query, |
|
|
"context": context |
|
|
} |
|
|
|
|
|
def get_retriever_runnable(self): |
|
|
"""Returns the retriever as a LangChain Runnable.""" |
|
|
return self.retriever |
|
|
|
|
|
class WemaRAGLoaderRunnable:
    """
    Wraps the loading functionality as a LangChain Runnable.
    """

    def __init__(self, rag_system):
        """
        Initialize with a WemaRAGSystem instance.

        Args:
            rag_system: An initialized WemaRAGSystem object
        """
        self.rag = rag_system
        self.loader = RunnableLambda(self._load_index)

    def _load_index(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Restore the FAISS index and chunk store from disk.

        Args:
            inputs: Dict that may carry 'index_path' and 'chunks_path'
                overrides (defaults: wema_faiss.index / wema_chunks.pkl).

        Returns:
            Dict describing what was loaded and from where.
        """
        # Fall back to the default artifact names when not supplied.
        index_file = inputs.get("index_path", "wema_faiss.index")
        chunk_file = inputs.get("chunks_path", "wema_chunks.pkl")

        self.rag.load(index_path=index_file, chunks_path=chunk_file)

        return {
            "index_path": index_file,
            "chunks_path": chunk_file,
            "chunk_count": len(self.rag.chunks),
            "index_size": self.rag.index.ntotal,
            "status": "loaded",
        }

    def get_loader_runnable(self):
        """Returns the loader as a LangChain Runnable."""
        return self.loader
|
|
|
|
|
def create_wema_rag_chain(rag_system, google_api_key: str,
                          model_name: str = "gemini-2.0-flash-exp",
                          temperature: float = 0.7):
    """
    Creates a complete LangChain RAG chain using the WemaRAGSystem.

    Args:
        rag_system: An initialized WemaRAGSystem object
        google_api_key: Google API key for Gemini
        model_name: Gemini model identifier. Defaults to the previously
            hard-coded "gemini-2.0-flash-exp" so existing callers are unaffected.
        temperature: LLM sampling temperature. Defaults to the previously
            hard-coded 0.7.

    Returns:
        A LangChain runnable chain: passthrough -> retriever -> prompt -> LLM -> str
    """
    # Retrieval step: wraps rag_system.get_context_for_rag as a Runnable.
    wema_retriever = WemaRAGRetrieverRunnable(rag_system)

    llm = ChatGoogleGenerativeAI(
        model=model_name,
        google_api_key=google_api_key,
        temperature=temperature
    )

    # The retriever emits {"query": ..., "context": ...}, which fills the
    # {query}/{context} slots of this template.
    prompt = ChatPromptTemplate.from_template("""
You are **Wema Assist**, an intelligent, reliable, and customer-friendly virtual assistant for **Wema Bank**.
Your goal is to provide accurate, thoughtful, and helpful responses to users — just like a top-tier banking representative — but with the warmth and empathy of a human advisor.
You have access to verified Wema Bank documentation provided below as *Context*.

Your tone should be:
- professional yet conversational,
- confident but not pushy,
- and focused on solving the user's need while representing Wema Bank's products effectively.

---

### 🎯 Core Guidelines:
1. **Use the provided "Context" first.**
   - It contains factual information from official Wema Bank documents.
   - Do not invent product details — rely on the context where possible.

2. **Always connect answers to Wema Bank offerings.**
   - When relevant, mention products or services such as:
     - Wema Savings Account
     - ALAT Digital Bank or ALAT Savings Goals
     - Wema Kiddies Account
     - Business or SME Banking
     - Wema Loans
     - Wema Security Tips or Scam Alerts
   - Even if the user query seems general, highlight any Wema product that could help.

3. **Be natural and practical.**
   - Offer useful, step-by-step guidance.
   - Use phrasing like:
     - "At Wema Bank, you can..."
     - "A good option through Wema is..."
     - "Wema's ALAT platform allows you to..."

4. **If the context isn't related to the query:**
   - Simply give a general, thoughtful answer — *without apologizing or saying the context is irrelevant.*

---

### 📘 Information You Have:

**Context:**
{context}

**User Query:**
{query}

---

### 🧠 Task:
Answer the query in a complete, natural, and customer-friendly way — integrating Wema Bank products or services wherever relevant.
If the RAG and context are not related, just give a general answer and don't complain.

### 💬 Final Answer:
""")

    chain = (
        RunnablePassthrough()
        | wema_retriever.get_retriever_runnable()
        | prompt
        | llm
        | StrOutputParser()
    )

    return chain
|
|
|
|
|
def create_gradio_interface(rag_system, google_api_key: str):
    """
    Creates a Gradio interface using the LangChain RAG chain.

    Args:
        rag_system: An initialized WemaRAGSystem object
        google_api_key: Google API key for Gemini

    Returns:
        Gradio Interface object
    """
    rag_chain = create_wema_rag_chain(rag_system, google_api_key)

    def chat_function(query: str) -> str:
        """Invoke the chain for Gradio, surfacing any failure as text."""
        try:
            return rag_chain.invoke({"query": query})
        except Exception as e:
            # Gradio renders the returned string, so report errors inline.
            return f"An error occurred: {str(e)}"

    query_box = gr.Textbox(
        label="Enter your query about Wema Bank:",
        placeholder="Ask me anything about Wema Bank products and services..."
    )
    answer_box = gr.Textbox(
        label="Wema Assist Response:",
        lines=10
    )

    return gr.Interface(
        fn=chat_function,
        inputs=query_box,
        outputs=answer_box,
        title="🏦 Wema Bank RAG Chatbot (LangChain Edition)",
        description="Powered by LangChain and your custom Wema RAG System",
        theme="soft"
    )
|
|
|
|
|
|
|
|
# --- Build the index from the cleaned crawl data ---
rag = WemaRAGSystem()
processor = WemaDocumentProcessorRunnable(rag)

# Run the full load -> chunk -> embed -> store pipeline in one call.
pipeline_inputs = {
    "json_path": "wema_cleaned.json",
    "index_path": "wema_faiss.index",
    "chunks_path": "wema_chunks.pkl",
}
result = processor.get_full_pipeline().invoke(pipeline_inputs)

print(f"Processing complete!")
print(f"Chunks created: {result['chunk_count']}")
print(f"Index size: {result['index_size']}")
print(f"Saved to: {result['index_path']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
# Cell 4: Create and launch Gradio interface |
|
|
from google.colab import userdata |
|
|
|
|
|
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY') |
|
|
iface = create_gradio_interface(rag, GOOGLE_API_KEY) |
|
|
iface.launch() |
|
|
""" |
|
|
|
|
|
''' |
|
|
# Cell 2: Set up your RAG system (your existing code) |
|
|
rag = WemaRAGSystem() |
|
|
rag.load() # Load your existing index |
|
|
|
|
|
# Cell 3: Initialize API keys |
|
|
from google.colab import userdata |
|
|
|
|
|
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY') |
|
|
SPITCH_API_KEY = userdata.get('SPITCH_API_KEY') # Add this to your Colab secrets |
|
|
|
|
|
# Cell 4: Launch voice interface |
|
|
iface = create_voice_gradio_interface( |
|
|
rag_system=rag, |
|
|
google_api_key=GOOGLE_API_KEY, |
|
|
spitch_api_key=SPITCH_API_KEY |
|
|
) |
|
|
iface.launch(share=True) |
|
|
''' |
|
|
|
|
|
|
|
|
# --- Load the prebuilt index and launch the voice-enabled interface ---
rag = WemaRAGSystem()
rag.load()

# API keys come from environment secrets (e.g. Hugging Face Space settings);
# `os` is already imported at the top of the file.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")

if not (GOOGLE_API_KEY and SPITCH_API_KEY):
    raise ValueError("Missing one or more API keys. Make sure they are added as secrets in your Space.")

chain = create_wema_rag_chain(rag, GOOGLE_API_KEY)

iface = create_voice_gradio_interface(
    rag_system=rag,
    chain=chain,
    spitch_api_key=SPITCH_API_KEY
)

iface.launch(share=True, debug=True)
|
|
|