VOC / app.py
mozzicato's picture
ok
b07af54 verified
# -*- coding: utf-8 -*-
"""voc6.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/17WecCovbP3TgYvHDyZ4Yckj77r2q5Nam
"""
# Cell to add FIRST - Your Original WemaRAGSystem
import json
import re
from typing import List, Dict, Tuple
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass
import pickle
import os
import io
from typing import Optional
from spitch import Spitch
import gradio as gr
# ============================================================================
# Wema Bank Voice-Enabled RAG Chatbot with Spitch Integration - CORRECTED
# ============================================================================
import tempfile
import os
import atexit
import glob
import io
from typing import Optional
from spitch import Spitch
import gradio as gr
# ============================================================================
# STEP 1: Initialize Spitch Client
# ============================================================================
class SpitchVoiceHandler:
    """
    Handles all voice-related operations using Spitch API.

    Wraps a ``Spitch`` client and exposes speech-to-text, translation and
    text-to-speech helpers. All three helpers are fail-soft: exceptions are
    printed and a fallback value is returned instead of being raised.
    """

    def __init__(self, api_key: str):
        """
        Initialize Spitch client.

        Args:
            api_key: Your Spitch API key
        """
        self.client = Spitch(api_key=api_key)

    @staticmethod
    def _extract_transcription(response):
        """Pull the transcript out of whatever shape the transcribe call returned."""
        if hasattr(response, 'text') and callable(response.text):
            # `.text` is a method, not an attribute
            return response.text()
        if hasattr(response, 'text'):
            # `.text` is a plain attribute
            return response.text
        if hasattr(response, 'json'):
            # Fall back to the parsed JSON payload
            json_data = response.json()
            return json_data.get('text', str(json_data))
        # Last resort: stringify the response object
        return str(response)

    def transcribe_audio(
        self,
        audio_file,
        source_language: str = "en",
        model: str = "mansa_v1"
    ) -> str:
        """
        Transcribe audio to text using Spitch.

        Args:
            audio_file: Audio file path (str) or an already-open file-like object
            source_language: Language code (e.g., 'en', 'yo', 'ig', 'ha')
            model: Spitch model to use (default: mansa_v1)

        Returns:
            Transcribed text. On any failure the error is printed and an
            apology string containing the error message is returned instead
            (this method never raises).
        """
        try:
            print(f"🎤 Transcribing audio file: {audio_file}")
            if isinstance(audio_file, str):
                # A path: open it only for the duration of the upload
                with open(audio_file, 'rb') as f:
                    response = self.client.speech.transcribe(
                        content=f,
                        language=source_language,
                        model=model
                    )
            else:
                # Assume it's already a file-like object (from Gradio)
                response = self.client.speech.transcribe(
                    content=audio_file,
                    language=source_language,
                    model=model
                )
            print(f"Response type: {type(response)}")
            transcription_text = self._extract_transcription(response)
            print(f"✅ Transcription: {transcription_text}")
            return transcription_text
        except Exception as e:
            print(f"❌ Transcription error: {e}")
            import traceback
            traceback.print_exc()
            return f"Sorry, I couldn't understand the audio. Error: {str(e)}"

    def translate_to_english(self, text: str, source_lang: str = "auto") -> str:
        """
        Translate text to English using Spitch translation API.

        Args:
            text: Text to translate
            source_lang: Source language code, passed through to the API as
                ``source`` ('en' short-circuits without an API call)

        Returns:
            Translated text in English, or the original text when it is
            already English, blank, or the translation call fails.
        """
        try:
            # Nothing to translate: already English, or empty/whitespace-only
            if source_lang == "en" or not text or not text.strip():
                return text
            translation = self.client.text.translate(
                text=text,
                source=source_lang,
                target="en"
            )
            return translation.text
        except Exception as e:
            print(f"Translation error: {e}")
            return text  # Return original if translation fails

    def synthesize_speech(
        self,
        text: str,
        target_language: str = "en",
        voice: str = "lina"
    ) -> Optional[bytes]:
        """
        Convert text to speech using Spitch TTS.

        Args:
            text: Text to convert to speech
            target_language: Target language for speech
            voice: Voice to use (e.g., 'lina')

        Returns:
            Audio bytes, or None when the text is blank, the response has no
            ``read()`` method, or the API call fails.
        """
        try:
            # Blank text cannot produce audio; skip the API round trip
            if not text or not text.strip():
                print("❌ TTS error: empty text")
                return None
            response = self.client.speech.generate(
                text=text,
                language=target_language,
                voice=voice
            )
            # The response is expected to expose .read() returning the bytes
            if hasattr(response, 'read'):
                audio_bytes = response.read()
                print(f"✅ TTS generated {len(audio_bytes)} bytes of audio")
                return audio_bytes
            print(f"❌ Response type: {type(response)}")
            print(f"❌ Response attributes: {dir(response)}")
            return None
        except Exception as e:
            print(f"❌ TTS error: {e}")
            import traceback
            traceback.print_exc()
            return None
# ============================================================================
# STEP 2: Integrate Voice with Your LangChain RAG System
# ============================================================================
class WemaVoiceAssistant:
    """
    Complete voice-enabled assistant combining Spitch voice I/O
    with your existing Wema RAG system.
    """

    def __init__(
        self,
        rag_system,
        chain,
        spitch_api_key: str
    ):
        """
        Initialize the voice assistant.

        Args:
            rag_system: Your initialized WemaRAGSystem
            chain: Your LangChain RAG chain (already created)
            spitch_api_key: Spitch API key
        """
        self.rag_system = rag_system
        self.voice_handler = SpitchVoiceHandler(spitch_api_key)
        self.chain = chain

    def _translate_response(self, response_text, output_language: str):
        """Translate the English answer, falling back to English on failure."""
        if output_language == "en":
            return response_text
        print(f"Translating response to {output_language}...")
        try:
            translation = self.voice_handler.client.text.translate(
                text=response_text,
                source="en",
                target=output_language
            )
            return translation.text
        except Exception as e:
            # A translation outage should not discard the answer that was
            # already computed; answer in English instead.
            print(f"⚠️ Translation failed: {e}, using English")
            return response_text

    def process_voice_query(
        self,
        audio_input,
        input_language: str = "en",
        output_language: str = "en",
        voice: str = "lina"
    ):
        """
        Complete voice interaction pipeline:
        1. Speech to text (any language)
        2. Translate to English if needed
        3. Query RAG system
        4. Translate response if needed (falls back to English on failure)
        5. Text to speech

        Args:
            audio_input: Audio file from user
            input_language: User's spoken language code
            output_language: Desired response language code
            voice: TTS voice to use

        Returns:
            tuple: (response_text, response_audio). ``response_audio`` is
            whatever ``synthesize_speech`` returned (bytes or None). On an
            unexpected error the tuple is (error message, None).
        """
        try:
            # Step 1: Transcribe audio to text
            print(f"Transcribing audio in {input_language}...")
            transcribed_text = self.voice_handler.transcribe_audio(
                audio_input,
                source_language=input_language
            )
            print(f"Transcribed: {transcribed_text}")

            # Step 2: Translate to English if not already
            if input_language != "en":
                print("Translating to English...")
                english_query = self.voice_handler.translate_to_english(
                    transcribed_text,
                    source_lang=input_language
                )
            else:
                english_query = transcribed_text
            print(f"English query: {english_query}")

            # Step 3: Get response from RAG system (in English)
            print("Querying RAG system...")
            response_text = self.chain.invoke({"query": english_query})
            print(f"RAG response: {response_text[:100]}...")

            # Step 4: Translate response if needed
            final_text = self._translate_response(response_text, output_language)

            # Step 5: Generate speech
            print("Generating speech...")
            audio_response = self.voice_handler.synthesize_speech(
                final_text,
                target_language=output_language,
                voice=voice
            )
            return final_text, audio_response
        except Exception as e:
            error_msg = f"An error occurred: {str(e)}"
            print(error_msg)
            return error_msg, None
# ============================================================================
# STEP 3: Helper Functions for Audio File Management
# ============================================================================
# Paths of the temp audio files created by this process. Cleanup removes only
# these, never other files that happen to sit in the shared temp directory.
_TEMP_AUDIO_FILES = set()


def save_audio_to_temp_file(audio_bytes):
    """
    Save audio bytes to a temporary .mp3 file and return the path.

    Args:
        audio_bytes: Raw audio data, or None

    Returns:
        Path of the created file, or None when ``audio_bytes`` is None.
        The file is kept on disk (delete=False) until
        ``cleanup_temp_audio_files`` runs.
    """
    if audio_bytes is None:
        return None
    # The context manager closes the handle even if the write fails
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
        temp_file.write(audio_bytes)
    _TEMP_AUDIO_FILES.add(temp_file.name)
    return temp_file.name


def cleanup_temp_audio_files():
    """Delete the temporary audio files created by save_audio_to_temp_file."""
    while _TEMP_AUDIO_FILES:
        path = _TEMP_AUDIO_FILES.pop()
        try:
            os.remove(path)
        except OSError:
            # Best effort: the file may already be gone or still be in use
            pass


# Register cleanup function to run on exit
atexit.register(cleanup_temp_audio_files)
# ============================================================================
# STEP 4: Create Gradio Interface (With Text AND Voice Options)
# ============================================================================
def create_voice_gradio_interface(
    rag_system,
    chain,
    spitch_api_key: str
):
    """
    Create a Gradio interface with BOTH text and voice input/output capabilities.

    Args:
        rag_system: Your initialized WemaRAGSystem
        chain: Your LangChain RAG chain (already created)
        spitch_api_key: Spitch API key

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    # Initialize voice assistant
    assistant = WemaVoiceAssistant(rag_system, chain, spitch_api_key)

    # Voice-language mapping: language name -> Spitch code + voices for it
    LANGUAGE_CONFIG = {
        "English": {
            "code": "en",
            "voices": ["john", "lucy", "lina", "jude", "henry", "kani", "kingsley",
                       "favour", "comfort", "daniel", "remi"]
        },
        "Yoruba": {
            "code": "yo",
            "voices": ["sade", "funmi", "segun", "femi"]
        },
        "Igbo": {
            "code": "ig",
            "voices": ["obinna", "ngozi", "amara", "ebuka"]
        },
        "Hausa": {
            "code": "ha",
            "voices": ["hasan", "amina", "zainab", "aliyu"]
        }
    }
    # Extract just language names for dropdowns
    ALL_LANGUAGES = list(LANGUAGE_CONFIG.keys())

    def handle_text_query(text_input):
        """Handle text-only queries; returns (response_text, None)."""
        if not text_input or text_input.strip() == "":
            return "Please enter a question.", None
        try:
            response = chain.invoke({"query": text_input})
            return response, None
        except Exception as e:
            return f"Error: {str(e)}", None

    def update_voices(language):
        """Update voice dropdown based on selected language."""
        voices = LANGUAGE_CONFIG.get(language, {}).get("voices", ["lina"])
        return gr.Dropdown(choices=voices, value=voices[0])

    def handle_voice_interaction(audio, input_lang, output_lang, voice):
        """
        Gradio handler for the voice tab.

        Runs transcribe -> translate to English -> RAG -> translate back ->
        TTS, then writes the audio to a temp file because the output
        component is configured with type="filepath".

        Returns:
            tuple: (response_text, audio_file_path or None)
        """
        print("="*60)
        print("VOICE INTERACTION STARTED")
        print(f"Audio input: {audio}")
        print(f"Input language: {input_lang}")
        print(f"Output language: {output_lang}")
        print(f"Voice: {voice}")
        print("="*60)
        if audio is None:
            return "Please record or upload audio.", None

        # Get language codes and voices (unknown names fall back to English)
        input_config = LANGUAGE_CONFIG.get(input_lang, LANGUAGE_CONFIG["English"])
        output_config = LANGUAGE_CONFIG.get(output_lang, LANGUAGE_CONFIG["English"])
        input_code = input_config["code"]
        output_code = output_config["code"]

        # Validate voice for output language
        available_voices = output_config["voices"]
        if voice not in available_voices:
            voice = available_voices[0]
            print(f"⚠️ Voice changed to {voice} for {output_lang}")

        try:
            # Process voice query
            print("\n🎤 Processing voice query...")
            # Step 1: Transcribe
            transcribed_text = assistant.voice_handler.transcribe_audio(
                audio,
                source_language=input_code
            )
            print(f"📝 Transcribed: {transcribed_text}")

            # Step 2: Translate to English if needed
            if input_code != "en":
                print("🌍 Translating to English...")
                english_query = assistant.voice_handler.translate_to_english(
                    transcribed_text,
                    source_lang=input_code
                )
            else:
                english_query = transcribed_text
            print(f"🇬🇧 English query: {english_query}")

            # Step 3: Get RAG response
            print("🔍 Querying RAG system...")
            response_text = assistant.chain.invoke({"query": english_query})
            print(f"✅ RAG response: {response_text[:100]}...")

            # Step 4: Translate response text if needed
            if output_code != "en":
                print(f"🌍 Translating response to {output_lang}...")
                try:
                    translation = assistant.voice_handler.client.text.translate(
                        text=response_text,
                        source="en",
                        target=output_code
                    )
                    final_text = translation.text
                except Exception as e:
                    print(f"⚠️ Translation failed: {e}, using English")
                    final_text = response_text
            else:
                final_text = response_text

            # Step 5: Generate speech in the target language with correct voice
            print(f"🔊 Generating speech in {output_lang} with voice {voice}...")
            audio_bytes = assistant.voice_handler.synthesize_speech(
                final_text,
                target_language=output_code,
                voice=voice
            )
            print(f"🔊 Audio bytes type: {type(audio_bytes)}")
            print(f"🔊 Audio bytes length: {len(audio_bytes) if audio_bytes else 0}")

            # Convert audio bytes to a file path for the Audio component
            audio_file_path = None
            if audio_bytes:
                print("\n💾 Saving audio to temp file...")
                audio_file_path = save_audio_to_temp_file(audio_bytes)
                print(f"✅ Audio saved to: {audio_file_path}")
                # Verify file exists and has content
                if audio_file_path and os.path.exists(audio_file_path):
                    file_size = os.path.getsize(audio_file_path)
                    print(f"✅ File size: {file_size} bytes")
                else:
                    print("❌ File was not created properly!")
            else:
                print("❌ No audio bytes received from TTS")
            print("="*60)
            return final_text, audio_file_path
        except Exception as e:
            error_msg = f"Error processing voice: {str(e)}"
            print(f"\n❌ ERROR: {error_msg}")
            import traceback
            traceback.print_exc()
            print("="*60)
            return error_msg, None

    # Create Gradio interface with BOTH text and voice
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🏦 Wema Bank AI Assistant
        ### Powered by Spitch AI & LangChain RAG
        Choose how you want to interact: Type or Speak!
        """)
        with gr.Tabs():
            # TEXT TAB
            with gr.Tab("💬 Text Chat"):
                gr.Markdown("### Type your banking questions")
                text_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask me anything about Wema Bank products and services...",
                    lines=3
                )
                text_submit_btn = gr.Button("📤 Send", variant="primary", size="lg")
                text_output = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False
                )
                # handle_text_query returns (text, None); this hidden component
                # receives the unused second value for both triggers below.
                hidden_audio = gr.Audio(visible=False)
                # Examples for text
                gr.Examples(
                    examples=[
                        ["What is ALAT?"],
                        ["How do I open a savings account?"],
                        ["Tell me about Wema Kiddies Account"],
                        ["How can I avoid phishing scams?"],
                        ["What loans does Wema Bank offer?"]
                    ],
                    inputs=text_input,
                    label="💡 Try these questions"
                )
                text_submit_btn.click(
                    fn=handle_text_query,
                    inputs=text_input,
                    outputs=[text_output, hidden_audio]
                )
                # Also submit on Enter
                text_input.submit(
                    fn=handle_text_query,
                    inputs=text_input,
                    outputs=[text_output, hidden_audio]
                )

            # VOICE TAB
            with gr.Tab("🎤 Voice Chat"):
                gr.Markdown("""
                ### Speak your banking questions in your language!
                **✅ Fully Supported Nigerian Languages:**
                - 🇬🇧 **English** - 11 voices available
                - 🇳🇬 **Yoruba** - 4 voices (Sade, Funmi, Segun, Femi)
                - 🇳🇬 **Igbo** - 4 voices (Obinna, Ngozi, Amara, Ebuka)
                - 🇳🇬 **Hausa** - 4 voices (Hasan, Amina, Zainab, Aliyu)
                Speak naturally and get responses in both text and audio in your preferred language!
                """)
                with gr.Row():
                    with gr.Column():
                        audio_input = gr.Audio(
                            sources=["microphone", "upload"],
                            type="filepath",
                            label="🎙️ Record or Upload Audio"
                        )
                        input_language = gr.Dropdown(
                            choices=ALL_LANGUAGES,
                            value="English",
                            label="Your Language (Speech Input)"
                        )
                    with gr.Column():
                        output_language = gr.Dropdown(
                            choices=ALL_LANGUAGES,
                            value="English",
                            label="Response Language (Audio Output)"
                        )
                        voice_selection = gr.Dropdown(
                            choices=LANGUAGE_CONFIG["English"]["voices"],
                            value="lina",
                            label="Voice"
                        )
                # Update voices when output language changes
                output_language.change(
                    fn=update_voices,
                    inputs=output_language,
                    outputs=voice_selection
                )
                voice_submit_btn = gr.Button("🚀 Ask Wema Assist", variant="primary", size="lg")
                voice_text_output = gr.Textbox(
                    label="📝 Text Response",
                    lines=8,
                    interactive=False
                )
                voice_audio_output = gr.Audio(
                    label="🔊 Audio Response",
                    type="filepath"  # the handler returns a file path
                )
                voice_submit_btn.click(
                    fn=handle_voice_interaction,
                    inputs=[audio_input, input_language, output_language, voice_selection],
                    outputs=[voice_text_output, voice_audio_output]
                )
        gr.Markdown("""
        ---
        ### 📌 Features
        - **Text Chat**: Fast and simple - just type and get instant responses
        - **Voice Chat**: Full support for Nigerian languages!
        ### 🇳🇬 Supported Nigerian Languages
        ✅ **English** - 11 different voices (male & female)
        ✅ **Yoruba** - E ku ọjọ! (4 authentic Yoruba voices)
        ✅ **Igbo** - Nnọọ! (4 authentic Igbo voices)
        ✅ **Hausa** - Sannu! (4 authentic Hausa voices)
        💡 **All features work in every language:**
        - 🎤 Speak your question in your language
        - 📝 Get text response translated
        - 🔊 Hear authentic audio response in your language
        - 🔄 Seamless translation between languages
        """)
    return demo
# ============================================================================
# ALTERNATIVE: Simpler Hybrid Interface
# ============================================================================
def create_hybrid_interface(
    rag_system,
    chain,
    spitch_api_key: str
):
    """
    Creates a simpler interface supporting both text and voice input.

    Args:
        rag_system: Your initialized WemaRAGSystem
        chain: Your LangChain RAG chain (already created)
        spitch_api_key: Spitch API key

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    assistant = WemaVoiceAssistant(rag_system, chain, spitch_api_key)

    # Language name -> Spitch code + the voices offered for that language.
    # Same mapping as create_voice_gradio_interface, repeated here so this
    # function stays self-contained.
    LANGUAGE_CONFIG = {
        "English": {
            "code": "en",
            "voices": ["john", "lucy", "lina", "jude", "henry", "kani", "kingsley",
                       "favour", "comfort", "daniel", "remi"]
        },
        "Yoruba": {
            "code": "yo",
            "voices": ["sade", "funmi", "segun", "femi"]
        },
        "Igbo": {
            "code": "ig",
            "voices": ["obinna", "ngozi", "amara", "ebuka"]
        },
        "Hausa": {
            "code": "ha",
            "voices": ["hasan", "amina", "zainab", "aliyu"]
        }
    }
    LANGUAGE_NAMES = list(LANGUAGE_CONFIG)

    def handle_text_query(text_input):
        """Handle text-only query; returns (response_text, None)."""
        if not text_input or not text_input.strip():
            return "Please enter a question.", None
        try:
            response = chain.invoke({"query": text_input})
            return response, None
        except Exception as e:
            return f"Error: {str(e)}", None

    def update_voices(language):
        """Offer only the voices that belong to the chosen output language."""
        voices = LANGUAGE_CONFIG.get(language, LANGUAGE_CONFIG["English"])["voices"]
        return gr.Dropdown(choices=voices, value=voices[0])

    def handle_voice_query(audio, input_lang, output_lang, voice):
        """Handle voice query; returns (response_text, audio_file_path or None)."""
        if audio is None:
            return "Please provide audio input.", None

        # Unknown language names fall back to English
        input_config = LANGUAGE_CONFIG.get(input_lang, LANGUAGE_CONFIG["English"])
        output_config = LANGUAGE_CONFIG.get(output_lang, LANGUAGE_CONFIG["English"])

        # A voice left over from another language is replaced by the first
        # voice of the output language.
        if voice not in output_config["voices"]:
            voice = output_config["voices"][0]

        # Process voice query
        text_response, audio_bytes = assistant.process_voice_query(
            audio,
            input_language=input_config["code"],
            output_language=output_config["code"],
            voice=voice
        )

        # Convert audio bytes to file path
        audio_file_path = None
        if audio_bytes:
            audio_file_path = save_audio_to_temp_file(audio_bytes)
        return text_response, audio_file_path

    # Create tabbed interface
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏦 Wema Bank AI Assistant")
        with gr.Tabs():
            # Text Tab
            with gr.Tab("💬 Text Chat"):
                text_input = gr.Textbox(
                    label="Type your question",
                    placeholder="Ask about Wema Bank products and services..."
                )
                text_submit = gr.Button("Send")
                text_output = gr.Textbox(label="Response", lines=10)
                text_submit.click(
                    fn=handle_text_query,
                    inputs=text_input,
                    outputs=[text_output, gr.Audio(visible=False)]
                )

            # Voice Tab
            with gr.Tab("🎤 Voice Chat"):
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
                with gr.Row():
                    input_lang = gr.Dropdown(
                        LANGUAGE_NAMES,
                        value="English",
                        label="Input Language"
                    )
                    output_lang = gr.Dropdown(
                        LANGUAGE_NAMES,
                        value="English",
                        label="Output Language"
                    )
                    voice = gr.Dropdown(
                        LANGUAGE_CONFIG["English"]["voices"],
                        value="lina",
                        label="Voice"
                    )
                # Keep the voice list in step with the output language
                output_lang.change(
                    fn=update_voices,
                    inputs=output_lang,
                    outputs=voice
                )
                voice_submit = gr.Button("Ask")
                voice_text_output = gr.Textbox(label="Response Text", lines=8)
                voice_audio_output = gr.Audio(label="Audio Response", type="filepath")
                voice_submit.click(
                    fn=handle_voice_query,
                    inputs=[audio_input, input_lang, output_lang, voice],
                    outputs=[voice_text_output, voice_audio_output]
                )
    return demo
@dataclass
class DocumentChunk:
    """One piece of a source document plus the metadata attached to it."""

    # The chunk's text content
    text: str
    # Metadata dict attached to the chunk by whoever creates it
    metadata: Dict
    # Integer identifier assigned by the creator of the chunk
    chunk_id: int
class WemaDocumentChunker:
    """Handles intelligent chunking of Wema Bank documents."""

    # Common section headers in banking documents
    _SECTION_PATTERNS = (
        r'(Avoiding Financial and Phishing Scams)',
        r'(Keeping Your Card.*?Safe)',
        r'(E-mails and calls from.*?)',
        r'(Scam Alert Tips)',
        r'(Guard Yourself)',
        r'(Bank Verification Number)',
        r'(Personal Banking)',
        r'(Business Banking)',
        r'(Corporate Banking)',
        r'(.*?Account)',
        r'(.*?Loan.*?)',
    )
    # Compiled once for the class instead of on every identify_sections call
    _SECTION_REGEX = re.compile('|'.join(_SECTION_PATTERNS), re.IGNORECASE)

    def __init__(self, chunk_size: int = 800, overlap: int = 150):
        """
        Initialize the chunker.

        Args:
            chunk_size: Target size for each chunk in characters
            overlap: Number of characters to overlap between chunks
        """
        self.chunk_size = chunk_size
        self.overlap = overlap

    def identify_sections(self, text: str) -> List[Tuple[str, str]]:
        """
        Identify logical sections in the document.

        The text is split at every header match. Each section runs from its
        header to the start of the next one. Non-blank text before the first
        header is kept as a leading "General Content" section rather than
        being discarded.

        Returns:
            List of tuples (section_title, section_content). When no header
            matches, a single ("General Content", text) entry.
        """
        matches = list(self._SECTION_REGEX.finditer(text))
        if not matches:
            # If no clear sections, treat as one section
            return [("General Content", text)]

        sections = []
        preamble = text[:matches[0].start()].strip()
        if preamble:
            sections.append(("General Content", preamble))

        # Pair every header with its successor (None marks the last one)
        for current, following in zip(matches, matches[1:] + [None]):
            end = following.start() if following is not None else len(text)
            sections.append((
                current.group(0).strip(),
                text[current.start():end].strip()
            ))
        return sections

    def chunk_text(self, text: str, metadata: Dict) -> "List[DocumentChunk]":
        """
        Chunk text with semantic awareness and overlap.

        Sections that fit within ``chunk_size`` become one chunk each. Longer
        sections are split on sentence boundaries, and each new chunk starts
        with the sentences recovered from the last ``overlap`` characters of
        the previous one. A single sentence longer than ``chunk_size`` is
        never split, so chunks can exceed the target size.

        Args:
            text: Text to chunk
            metadata: Metadata to attach to chunks (copied per chunk, with a
                'section' key added)

        Returns:
            List of DocumentChunk objects with consecutive chunk_id values
            starting at 0
        """
        chunks = []

        def emit(body: str, section_title: str) -> None:
            chunk_metadata = metadata.copy()
            chunk_metadata['section'] = section_title
            # chunk_id is simply the chunk's position in the output list
            chunks.append(DocumentChunk(
                text=body,
                metadata=chunk_metadata,
                chunk_id=len(chunks)
            ))

        for section_title, section_content in self.identify_sections(text):
            # If section is smaller than chunk_size, keep it whole
            if len(section_content) <= self.chunk_size:
                emit(section_content, section_title)
                continue

            # Split section into smaller chunks with overlap
            pending = []
            pending_length = 0
            for sentence in self._split_into_sentences(section_content):
                if pending and pending_length + len(sentence) > self.chunk_size:
                    body = ' '.join(pending)
                    emit(body, section_title)
                    # Keep the tail of this chunk as the start of the next.
                    # overlap <= 0 needs its own branch: body[-0:] would be
                    # the whole string, not an empty tail.
                    if self.overlap <= 0:
                        tail = ''
                    elif len(body) > self.overlap:
                        tail = body[-self.overlap:]
                    else:
                        tail = body
                    pending = self._split_into_sentences(tail)
                    pending_length = sum(len(s) for s in pending)
                pending.append(sentence)
                pending_length += len(sentence)

            # Add remaining chunk
            if pending:
                emit(' '.join(pending), section_title)
        return chunks

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split text on whitespace following '.', '!' or '?'; drop empty pieces."""
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]
class WemaRAGSystem:
    """Complete RAG system for Wema Bank documents."""

    def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
        """
        Initialize the RAG system.

        Args:
            model_name: Name of the sentence transformer model to use
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.dimension = self.model.get_sentence_embedding_dimension()
        self.index = None
        self.chunks = []
        self.chunker = WemaDocumentChunker()

    def load_and_process_document(self, json_path: str):
        """
        Load JSON document, chunk it, and create embeddings.

        Args:
            json_path: Path to the JSON file. It must contain one document
                object or a list of them; the 'text', 'url', 'title' and
                'meta_description' keys are read, each defaulting to ''.

        Raises:
            ValueError: If the JSON root is neither an object nor a list
        """
        print(f"Loading document from: {json_path}")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Process each document in the JSON
        if isinstance(data, list):
            documents = data
        elif isinstance(data, dict):
            documents = [data]
        else:
            raise ValueError("JSON must contain a document object or list of documents")

        all_chunks = []
        for doc in documents:
            text = doc.get('text', '')
            metadata = {
                'url': doc.get('url', ''),
                'title': doc.get('title', ''),
                'meta_description': doc.get('meta_description', '')
            }
            # Chunk the document
            chunks = self.chunker.chunk_text(text, metadata)
            all_chunks.extend(chunks)
            print(f"Created {len(chunks)} chunks from document: {metadata['title'][:50]}...")

        self.chunks = all_chunks
        print(f"Total chunks created: {len(self.chunks)}")
        # Generate embeddings
        self._create_embeddings()

    def _create_embeddings(self):
        """
        Generate embeddings for all chunks and create FAISS index.

        Raises:
            ValueError: If there are no chunks to embed
        """
        if not self.chunks:
            # Fail with a clear message instead of an opaque shape error
            # further down when nothing was chunked.
            raise ValueError("No chunks to embed. Process documents first.")
        print("Generating embeddings...")
        texts = [chunk.text for chunk in self.chunks]
        embeddings = self.model.encode(texts, show_progress_bar=True)
        # Create FAISS index
        print("Creating FAISS index...")
        self.index = faiss.IndexFlatL2(self.dimension)
        self.index.add(embeddings.astype('float32'))
        print(f"FAISS index created with {self.index.ntotal} vectors")

    def save(self, index_path: str = 'wema_faiss.index',
             chunks_path: str = 'wema_chunks.pkl'):
        """
        Save FAISS index and chunks to disk.

        Args:
            index_path: Path to save FAISS index
            chunks_path: Path to save chunks metadata (pickle)

        Raises:
            ValueError: If no index has been built or loaded yet
        """
        if self.index is None:
            raise ValueError("No index to save. Process documents first.")
        print(f"Saving FAISS index to: {index_path}")
        faiss.write_index(self.index, index_path)
        print(f"Saving chunks metadata to: {chunks_path}")
        with open(chunks_path, 'wb') as f:
            pickle.dump(self.chunks, f)
        print("Save complete!")

    def load(self, index_path: str = 'wema_faiss.index',
             chunks_path: str = 'wema_chunks.pkl'):
        """
        Load FAISS index and chunks from disk.

        Args:
            index_path: Path to FAISS index
            chunks_path: Path to chunks metadata (pickle)
        """
        print(f"Loading FAISS index from: {index_path}")
        self.index = faiss.read_index(index_path)
        print(f"Loading chunks metadata from: {chunks_path}")
        # SECURITY: pickle.load executes arbitrary code from the file. Only
        # load chunk files this application wrote itself.
        with open(chunks_path, 'rb') as f:
            self.chunks = pickle.load(f)
        print(f"Loaded {len(self.chunks)} chunks with index size {self.index.ntotal}")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Search for relevant chunks given a query.

        Args:
            query: Search query
            top_k: Maximum number of results to return

        Returns:
            List of dictionaries with 'text', 'metadata', 'score' (the
            distance reported by the index, converted to float) and
            'chunk_id'. May hold fewer than ``top_k`` entries when the index
            is smaller.

        Raises:
            ValueError: If no index has been built or loaded yet
        """
        if self.index is None:
            raise ValueError("No index loaded. Load or create an index first.")

        # Never ask for more neighbours than the index holds
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []

        # Encode query
        query_embedding = self.model.encode([query])[0].astype('float32').reshape(1, -1)
        # Search
        distances, indices = self.index.search(query_embedding, k)

        # Prepare results
        results = []
        for distance, idx in zip(distances[0], indices[0]):
            if idx < 0:
                # A negative label means "no result". Used as a list index it
                # would silently pick the last chunk.
                continue
            chunk = self.chunks[idx]
            results.append({
                'text': chunk.text,
                'metadata': chunk.metadata,
                'score': float(distance),
                'chunk_id': chunk.chunk_id
            })
        return results

    def get_context_for_rag(self, query: str, top_k: int = 3,
                            max_context_length: int = 2000) -> str:
        """
        Get formatted context for RAG applications.

        Args:
            query: Search query
            top_k: Number of chunks to retrieve
            max_context_length: Maximum total length of the formatted entries.
                Collection stops at the first entry that would exceed it.

        Returns:
            Formatted context string: one "[Source i - section]" block per
            retrieved chunk, joined with newlines
        """
        results = self.search(query, top_k)
        context_parts = []
        current_length = 0
        for i, result in enumerate(results, 1):
            chunk_text = result['text']
            section = result['metadata'].get('section', 'N/A')
            # Format context with source information
            formatted = f"[Source {i} - {section}]\n{chunk_text}\n"
            if current_length + len(formatted) > max_context_length:
                break
            context_parts.append(formatted)
            current_length += len(formatted)
        return "\n".join(context_parts)
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableLambda
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_google_genai import ChatGoogleGenerativeAI
import gradio as gr
from typing import Dict, Any, List
import json
class WemaDocumentProcessorRunnable:
    """
    Wraps the document loading, chunking, embedding, and storing as a LangChain Runnable.

    Each stage takes the dict produced by the previous one. Optional
    'index_path' / 'chunks_path' keys given to the first stage are forwarded
    through every stage, so the final save honours them.
    """

    def __init__(self, rag_system):
        """
        Initialize with a WemaRAGSystem instance.

        Args:
            rag_system: An initialized WemaRAGSystem object
        """
        self.rag = rag_system
        # Create runnables for each step
        self.document_loader = RunnableLambda(self._load_document)
        self.chunker = RunnableLambda(self._chunk_documents)
        self.embedder = RunnableLambda(self._create_embeddings)
        self.storer = RunnableLambda(self._store_index)
        # Complete pipeline runnable
        self.full_pipeline = (
            self.document_loader
            | self.chunker
            | self.embedder
            | self.storer
        )

    @staticmethod
    def _forward_paths(inputs) -> Dict[str, Any]:
        """Return the caller-supplied save paths present in ``inputs``, if any."""
        if not isinstance(inputs, dict):
            return {}
        return {
            key: inputs[key]
            for key in ("index_path", "chunks_path")
            if key in inputs
        }

    def _load_document(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Loads JSON document(s).

        Args:
            inputs: Dictionary with 'json_path' key (optionally 'index_path'
                and 'chunks_path'), or the path itself

        Returns:
            Dictionary with 'json_path', 'documents' (always a list) and
            'status', plus any forwarded save paths

        Raises:
            ValueError: If the JSON root is neither an object nor a list
        """
        json_path = inputs.get("json_path", inputs) if isinstance(inputs, dict) else inputs
        print(f"Loading document from: {json_path}")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Process documents
        if isinstance(data, list):
            documents = data
        elif isinstance(data, dict):
            documents = [data]
        else:
            raise ValueError("JSON must contain a document object or list of documents")

        return {
            **self._forward_paths(inputs),
            "json_path": json_path,
            "documents": documents,
            "status": "loaded"
        }

    def _chunk_documents(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Chunks documents using the wrapped system's chunker.

        Side effect: replaces ``self.rag.chunks`` with the new chunks.

        Args:
            inputs: Dictionary with 'documents' key

        Returns:
            Dictionary with chunked documents
        """
        documents = inputs["documents"]
        print("Chunking documents...")
        all_chunks = []
        for doc in documents:
            text = doc.get('text', '')
            metadata = {
                'url': doc.get('url', ''),
                'title': doc.get('title', ''),
                'meta_description': doc.get('meta_description', '')
            }
            # Use the original chunker from WemaRAGSystem
            chunks = self.rag.chunker.chunk_text(text, metadata)
            all_chunks.extend(chunks)
            print(f"Created {len(chunks)} chunks from document: {metadata['title'][:50]}...")

        self.rag.chunks = all_chunks
        print(f"Total chunks created: {len(self.rag.chunks)}")
        return {
            **self._forward_paths(inputs),
            "json_path": inputs.get("json_path"),
            "documents": documents,
            "chunks": all_chunks,
            "chunk_count": len(all_chunks),
            "status": "chunked"
        }

    def _create_embeddings(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Creates embeddings and FAISS index using the original method.

        The chunks embedded are those stored on ``self.rag``; 'chunks' in
        ``inputs`` is only passed along.

        Args:
            inputs: Dictionary with 'documents', 'chunks' and 'chunk_count' keys

        Returns:
            Dictionary with embedding info
        """
        print("Generating embeddings...")
        # Use the original _create_embeddings method
        self.rag._create_embeddings()
        return {
            **self._forward_paths(inputs),
            "json_path": inputs.get("json_path"),
            "documents": inputs["documents"],
            "chunks": inputs["chunks"],
            "chunk_count": inputs["chunk_count"],
            "index_size": self.rag.index.ntotal,
            "status": "embedded"
        }

    def _store_index(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Saves FAISS index and chunks to disk.

        Args:
            inputs: Dictionary with processing results; 'index_path' and
                'chunks_path' default to 'wema_faiss.index' / 'wema_chunks.pkl'

        Returns:
            Dictionary with save status
        """
        index_path = inputs.get("index_path", "wema_faiss.index")
        chunks_path = inputs.get("chunks_path", "wema_chunks.pkl")
        # Use the original save method
        self.rag.save(index_path=index_path, chunks_path=chunks_path)
        return {
            "json_path": inputs.get("json_path"),
            "chunk_count": inputs["chunk_count"],
            "index_size": inputs["index_size"],
            "index_path": index_path,
            "chunks_path": chunks_path,
            "status": "saved"
        }

    def get_full_pipeline(self):
        """Returns the complete processing pipeline as a LangChain Runnable."""
        return self.full_pipeline

    def get_loader_runnable(self):
        """Returns just the document loader."""
        return self.document_loader

    def get_chunker_runnable(self):
        """Returns just the chunker."""
        return self.chunker

    def get_embedder_runnable(self):
        """Returns just the embedder."""
        return self.embedder

    def get_storer_runnable(self):
        """Returns just the storer."""
        return self.storer
class WemaRAGRetrieverRunnable:
    """
    Exposes context retrieval from a RAG system as a LangChain Runnable.
    """

    def __init__(self, rag_system):
        """
        Store the RAG system and build the retriever runnable.

        Args:
            rag_system: An initialized WemaRAGSystem object
        """
        self.rag = rag_system
        self.retriever = RunnableLambda(self._retrieve_context)

    def _retrieve_context(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Look up context for a query via ``get_context_for_rag`` (top_k=3).

        Args:
            inputs: Dictionary containing a 'query' key, or the query itself.
                A dict without 'query' is passed on unchanged as the query.

        Returns:
            Dictionary with 'query' and 'context'
        """
        if isinstance(inputs, dict):
            question = inputs.get("query", inputs)
        else:
            question = inputs
        retrieved = self.rag.get_context_for_rag(question, top_k=3)
        return dict(query=question, context=retrieved)

    def get_retriever_runnable(self):
        """Return the retriever runnable built in the constructor."""
        return self.retriever
class WemaRAGLoaderRunnable:
    """
    Exposes loading a saved index and its chunks as a LangChain Runnable.
    """

    def __init__(self, rag_system):
        """
        Store the RAG system and build the loader runnable.

        Args:
            rag_system: An initialized WemaRAGSystem object
        """
        self.rag = rag_system
        self.loader = RunnableLambda(self._load_index)

    def _load_index(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Load the index and chunks through the wrapped system's ``load`` method.

        Args:
            inputs: Dictionary that may carry 'index_path' and 'chunks_path';
                missing keys default to 'wema_faiss.index' / 'wema_chunks.pkl'

        Returns:
            Dictionary with the paths used, the chunk count, the index size
            and a 'loaded' status
        """
        paths = {
            "index_path": inputs.get("index_path", "wema_faiss.index"),
            "chunks_path": inputs.get("chunks_path", "wema_chunks.pkl"),
        }
        # Delegate to the wrapped system's load method
        self.rag.load(**paths)
        summary = dict(paths)
        summary["chunk_count"] = len(self.rag.chunks)
        summary["index_size"] = self.rag.index.ntotal
        summary["status"] = "loaded"
        return summary

    def get_loader_runnable(self):
        """Return the loader runnable built in the constructor."""
        return self.loader
def create_wema_rag_chain(rag_system, google_api_key: str):
    """
    Assemble the retrieval → prompt → Gemini → string-output chain.

    Args:
        rag_system: An initialized WemaRAGSystem object
        google_api_key: Google API key handed to ChatGoogleGenerativeAI

    Returns:
        A LangChain runnable; callers in this file invoke it with
        ``{"query": ...}``
    """
    # Retrieval step: produces the 'query' and 'context' values for the prompt.
    retrieval_step = WemaRAGRetrieverRunnable(rag_system).get_retriever_runnable()

    # Gemini chat model
    gemini = ChatGoogleGenerativeAI(
        google_api_key=google_api_key,
        model="gemini-2.0-flash-exp",
        temperature=0.7,
    )

    # Prompt text; {context} and {query} are filled from the retrieval output.
    prompt_text = """
You are **Wema Assist**, an intelligent, reliable, and customer-friendly virtual assistant for **Wema Bank**.
Your goal is to provide accurate, thoughtful, and helpful responses to users — just like a top-tier banking representative — but with the warmth and empathy of a human advisor.
You have access to verified Wema Bank documentation provided below as *Context*.
Your tone should be:
- professional yet conversational,
- confident but not pushy,
- and focused on solving the user's need while representing Wema Bank's products effectively.
---
### 🎯 Core Guidelines:
1. **Use the provided "Context" first.**
- It contains factual information from official Wema Bank documents.
- Do not invent product details — rely on the context where possible.
2. **Always connect answers to Wema Bank offerings.**
- When relevant, mention products or services such as:
- Wema Savings Account
- ALAT Digital Bank or ALAT Savings Goals
- Wema Kiddies Account
- Business or SME Banking
- Wema Loans
- Wema Security Tips or Scam Alerts
- Even if the user query seems general, highlight any Wema product that could help.
3. **Be natural and practical.**
- Offer useful, step-by-step guidance.
- Use phrasing like:
- "At Wema Bank, you can..."
- "A good option through Wema is..."
- "Wema's ALAT platform allows you to..."
4. **If the context isn't related to the query:**
- Simply give a general, thoughtful answer — *without apologizing or saying the context is irrelevant.*
---
### 📘 Information You Have:
**Context:**
{context}
**User Query:**
{query}
---
### 🧠 Task:
Answer the query in a complete, natural, and customer-friendly way — integrating Wema Bank products or services wherever relevant.
If the RAG and context are not related, just give a general answer and don't complain.
### 💬 Final Answer:
"""
    rag_prompt = ChatPromptTemplate.from_template(prompt_text)

    # Compose with LCEL, one stage at a time, in the same left-to-right order.
    pipeline = RunnablePassthrough() | retrieval_step
    pipeline = pipeline | rag_prompt
    pipeline = pipeline | gemini
    pipeline = pipeline | StrOutputParser()
    return pipeline
def create_gradio_interface(rag_system, google_api_key: str):
    """
    Build the text-only Gradio UI on top of the LangChain RAG chain.

    Args:
        rag_system: An initialized WemaRAGSystem object
        google_api_key: Google API key for Gemini

    Returns:
        Gradio Interface object (not yet launched)
    """
    # Build the chain once; every request reuses it through the closure below.
    rag_chain = create_wema_rag_chain(rag_system, google_api_key)

    def chat_function(query: str) -> str:
        """Run one query through the chain; report failures as text."""
        try:
            return rag_chain.invoke({"query": query})
        except Exception as e:
            # UI boundary: show the error in the output box instead of raising.
            return f"An error occurred: {str(e)}"

    # Input and output widgets
    query_box = gr.Textbox(
        label="Enter your query about Wema Bank:",
        placeholder="Ask me anything about Wema Bank products and services...",
    )
    answer_box = gr.Textbox(
        label="Wema Assist Response:",
        lines=10,
    )

    return gr.Interface(
        fn=chat_function,
        inputs=query_box,
        outputs=answer_box,
        title="🏦 Wema Bank RAG Chatbot (LangChain Edition)",
        description="Powered by LangChain and your custom Wema RAG System",
        theme="soft",
    )
# ============================================================================
# Application start-up: build the index, then serve the voice interface.
# ============================================================================

# Initialize RAG system
rag = WemaRAGSystem()

# Wrap it as a LangChain runnable
processor = WemaDocumentProcessorRunnable(rag)

# Run the complete pipeline (load → chunk → embed → store)
result = processor.get_full_pipeline().invoke({
    "json_path": "wema_cleaned.json",
    "index_path": "wema_faiss.index",
    "chunks_path": "wema_chunks.pkl",
})
print("Processing complete!")
print(f"Chunks created: {result['chunk_count']}")
print(f"Index size: {result['index_size']}")
print(f"Saved to: {result['index_path']}")

# Reload the persisted index into the SAME instance. A second WemaRAGSystem()
# is not created here: `processor` keeps the first one referenced, so building
# another would leave two full systems alive for the lifetime of the app.
rag.load()

# API keys come from environment variables (Space secrets).
# When running in Colab instead, read them with
# `google.colab.userdata.get("GOOGLE_API_KEY")` / `userdata.get("SPITCH_API_KEY")`.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
if not GOOGLE_API_KEY or not SPITCH_API_KEY:
    raise ValueError("Missing one or more API keys. Make sure they are added as secrets in your Space.")

# Launch voice interface.
# create_voice_gradio_interface needs the chain, not the Google API key
# directly, so the chain is created first.
chain = create_wema_rag_chain(rag, GOOGLE_API_KEY)
iface = create_voice_gradio_interface(
    rag_system=rag,
    chain=chain,  # Pass the created chain
    spitch_api_key=SPITCH_API_KEY,
)
# NOTE(review): share=True is presumably unnecessary when hosted on a Space —
# confirm before removing.
iface.launch(share=True, debug=True)