# app.py — D&S Product Assistant (Streamlit)
# Provenance: Hugging Face Space upload, commit c6ffb8c
import streamlit as st
import speech_recognition as sr
import tempfile
import os, sys
import openai
import json
import requests
import base64
from io import BytesIO
from requests.auth import HTTPBasicAuth
from typing import Tuple, Dict, List, Optional
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
from langdetect import detect
from audio_recorder_streamlit import audio_recorder
from dotenv import load_dotenv, find_dotenv
import pandas as pd
import pickle
import time
# Make parent directories importable (shared project modules live above this app).
sys.path.append("../..")
# Load environment variables from a .env file, if present.
_ = load_dotenv(find_dotenv())

# Constants
DB_FAISS_PATH = 'vectorstore/db_faiss'  # on-disk FAISS vector store for the RAG retriever

# ERP/product-catalog API credentials and endpoint (read from the environment).
API_USERNAME = os.getenv('API_USERNAME')
API_PASSWORD = os.getenv('API_PASSWORD')
BASE_URL = os.getenv('BASE_URL')

# NOTE(review): raises KeyError at import time if OPENAI_API_KEY is unset —
# intentional fail-fast, but worth confirming for deployment environments.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Technical terms to keep in English (never translated in any target language).
TECHNICAL_TERMS = [
    "Dayliff", "Pedrollo", "Grundfos", "PSK", "PV", "AC", "DC", "pH",
    "kW", "HP", "VAC", "Hz", "RPM", "IP", "ISO", "KEBS", "m³/hr"
]
class Component(BaseModel):
    """One line item of a quoted product solution (structured LLM output)."""

    no: str = Field(..., description="Product number")
    product_model: str = Field(..., description="Model of the product")
    item_category_code: str = Field(..., description="Category code")
    description: str = Field(..., description="Component description")
    quantity: int = Field(..., description="Quantity needed")
    # Prices default to 0.0 so the model validates even before pricing is filled in.
    unit_price: float = Field(0.0, description="Price per unit")
    gross_price: float = Field(0.0, description="Total cost")
class ProductResponse(BaseModel):
    """A complete quoted response: component list, totals, and explanation."""

    components: List[Component]
    # Monetary fields default to 0.0; computed totals are optional for the model.
    subtotal: float = Field(0.0, description="Subtotal before VAT")
    vat: float = Field(0.0, description="VAT amount")
    total: float = Field(0.0, description="Total cost including VAT")
    explanation: str = Field(..., description="Detailed explanation")
    additional_notes: Optional[str] = None
def process_audio_with_openai(audio_bytes: bytes, target_language: str, proficiency_level: str) -> Tuple[str, str, bytes]:
    """Transcribe recorded audio, translate it, and synthesize translated speech.

    Pipeline (three OpenAI HTTP endpoints):
      1. Whisper (`/v1/audio/transcriptions`) transcribes the recording.
      2. Chat completion (`gpt-4o-mini`) translates into ``target_language``,
         keeping TECHNICAL_TERMS untouched and adapting wording to
         ``proficiency_level``.
      3. TTS (`/v1/audio/speech`, `tts-1`) voices the translation, slower
         for beginners.

    Args:
        audio_bytes: Raw audio payload from the recorder widget.
        target_language: Language to translate the transcription into.
        proficiency_level: "Beginner", "Intermediate" or "Advanced".

    Returns:
        Tuple of (original transcription, translated text, translated audio bytes).

    Raises:
        Exception: If any of the three API calls returns an unexpected payload.
    """
    api_key = os.getenv("OPENAI_API_KEY")

    # --- 1. Transcription (Whisper) -------------------------------------
    # Use delete=False and clean up manually: on Windows a NamedTemporaryFile
    # opened with delete=True cannot be reopened for reading. The read handle
    # is a context manager so it is always closed (the original leaked it).
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_audio.write(audio_bytes)
            temp_path = temp_audio.name
        with open(temp_path, "rb") as audio_file:
            transcription_response = requests.post(
                "https://api.openai.com/v1/audio/transcriptions",
                headers={"Authorization": f"Bearer {api_key}"},
                files={"file": audio_file},
                data={"model": "whisper-1"}
            )
    finally:
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)

    transcription_data = transcription_response.json()
    if "text" not in transcription_data:
        raise Exception(f"Unexpected API response: {transcription_data}")
    original_text = transcription_data["text"]

    # --- 2. Translation (text-only chat model) --------------------------
    translation_prompt = f"Translate to {target_language}, keeping technical terms unchanged: {', '.join(TECHNICAL_TERMS)}. "
    translation_prompt += f"Adapt the language for a {proficiency_level.lower()} level of technical understanding."
    translation_response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        json={
            "model": "gpt-4o-mini",  # text translation needs no audio-capable model
            "messages": [
                {"role": "system", "content": translation_prompt},
                {"role": "user", "content": original_text}
            ]
        }
    )
    translation_data = translation_response.json()
    if "choices" not in translation_data or len(translation_data["choices"]) == 0:
        raise Exception(f"Unexpected translation API response: {translation_data}")
    translated_text = translation_data['choices'][0]['message']['content']

    # --- 3. Text-to-speech ----------------------------------------------
    audio_response = requests.post(
        "https://api.openai.com/v1/audio/speech",
        headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        json={
            "model": "tts-1",
            "input": translated_text,
            "voice": "alloy",
            "speed": 0.9 if proficiency_level == "Beginner" else 1.0  # slower for beginners
        }
    )
    if audio_response.status_code != 200:
        raise Exception(f"Error generating audio: {audio_response.text}")
    translated_audio = audio_response.content

    return original_text, translated_text, translated_audio
def initialize_chatbot(proficiency_level: str):
    """Build the RetrievalQA chain used to answer product questions.

    Loads the local FAISS vector store and bakes a proficiency-appropriate
    instruction directly into the prompt, so answers match the user's
    technical level.

    Fix: the original created a ConversationBufferMemory that was never
    attached to the chain (RetrievalQA here takes no memory argument); the
    dead object has been removed.

    Args:
        proficiency_level: "Beginner", "Intermediate" or "Advanced"
            (any other value raises KeyError).

    Returns:
        A RetrievalQA chain configured to also return source documents.
    """
    embedding_model = OpenAIEmbeddings(model="text-embedding-3-large")
    # allow_dangerous_deserialization: the index is a local, trusted artifact.
    db = FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
    faiss_retriever = db.as_retriever()

    proficiency_instructions = {
        "Beginner": "Use simple language and avoid technical jargon. Provide basic explanations.",
        "Intermediate": "Use a balanced mix of technical and simplified language. Provide moderate detail.",
        "Advanced": "You can use technical language and provide detailed, in-depth explanations."
    }
    instruction = proficiency_instructions[proficiency_level]

    # The proficiency instruction is embedded directly in the prompt template;
    # double braces leave {question}/{context} as LangChain placeholders.
    prompt_template = f"""You are a Davis & Shirtliff product expert. {instruction}
Use this context to answer: {{question}}
Context: {{context}}
Answer:"""
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["question", "context"]
    )

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0.3)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=faiss_retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
def get_product_details(no: str) -> Dict:
    """Look up one product record by its number via the ERP OData API.

    Args:
        no: Product number used in the ``$filter`` query.

    Returns:
        A normalized dict of product fields, or ``{}`` when the product is
        not found or the HTTP request fails (failures are surfaced in the
        Streamlit UI via ``st.error``).
    """
    query = {"$filter": f"No eq '{no}'"}
    try:
        resp = requests.get(
            BASE_URL,
            params=query,
            auth=HTTPBasicAuth(API_USERNAME, API_PASSWORD)
        )
        resp.raise_for_status()
        payload = resp.json()
    except requests.RequestException as e:
        st.error(f"Error fetching product details: {str(e)}")
        return {}

    # The OData payload wraps matches in a 'value' list; empty means no match.
    records = payload.get('value')
    if not records:
        return {}

    item = records[0]
    return {
        'no': item.get('No', ''),
        'inventory': int(item.get('Inventory', 0)),
        'unit_price': float(item.get('Unit_Price', 0)),
        'description': item.get('Description', ''),
        'item_category_code': item.get('Item_Category_Code', ''),
        'product_model': item.get('Product_Model', ''),
        'specifications': item.get('Technical_Specifications', ''),
        'warranty': item.get('Warranty_Period', '')
    }
def process_text_input(user_input: str, target_language: str, proficiency_level: str, qa_chain) -> Tuple[str, bytes]:
    """Answer a user question via the RAG chain, then translate and voice it.

    Steps: run the RetrievalQA chain, append live product details for any
    cited products, translate the answer when the target language is not
    English, and synthesize speech for the final text.

    Args:
        user_input: The user's question.
        target_language: Output language; "English" skips the translation call.
        proficiency_level: "Beginner", "Intermediate" or "Advanced".
        qa_chain: Initialized chain from ``initialize_chatbot``.

    Returns:
        Tuple of (answer text, answer audio bytes).

    Raises:
        Exception: If the translation or TTS API returns an error payload.
    """
    api_key = os.getenv('OPENAI_API_KEY')

    response = qa_chain({"query": user_input})
    answer = response.get('result', '')
    sources = response.get('source_documents', [])

    # Enrich the answer with live stock/spec data for any cited products.
    for doc in sources:
        if hasattr(doc, 'metadata') and 'product_no' in doc.metadata:
            details = get_product_details(doc.metadata['product_no'])
            if details:
                answer += f"\n\nProduct Details:\nModel: {details['product_model']}\n"
                answer += f"Stock: {details['inventory']}\n"
                answer += f"Specifications: {details['specifications']}\n"
                answer += f"Warranty: {details['warranty']}"

    # Translate if needed
    if target_language != "English":
        # Adjust complexity based on proficiency level
        complexity_instruction = {
            "Beginner": "Use simple language and avoid technical jargon.",
            "Intermediate": "Use a balanced mix of technical and simplified language.",
            "Advanced": "You can use technical language and detailed explanations."
        }
        translation_response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "model": "gpt-4o-mini",
                "messages": [
                    {"role": "system", "content": f"Translate to {target_language}, preserving technical terms: {', '.join(TECHNICAL_TERMS)}. {complexity_instruction[proficiency_level]}"},
                    {"role": "user", "content": answer}
                ]
            }
        )
        translation_data = translation_response.json()
        # Fail with a clear message instead of KeyError-ing on an API error payload.
        if "choices" not in translation_data or len(translation_data["choices"]) == 0:
            raise Exception(f"Unexpected translation API response: {translation_data}")
        answer = translation_data['choices'][0]['message']['content']

    # Generate audio for the answer
    audio_response = requests.post(
        "https://api.openai.com/v1/audio/speech",
        headers={"Authorization": f"Bearer {api_key}"},
        json={
            "model": "tts-1",
            "input": answer,
            "voice": "alloy",
            "speed": 0.9 if proficiency_level == "Beginner" else 1.0  # Slower for beginners
        }
    )
    # Check the TTS call explicitly (mirrors process_audio_with_openai).
    if audio_response.status_code != 200:
        raise Exception(f"Error generating audio: {audio_response.text}")
    answer_audio = audio_response.content

    return answer, answer_audio
def display_chat_message(is_user: bool, message: str, audio_bytes=None, is_loading=False):
    """Render one chat bubble (user or assistant) with avatar and optional audio.

    Args:
        is_user: True renders the user bubble, False the assistant bubble.
        message: HTML content of the bubble (ignored when ``is_loading``).
        audio_bytes: Optional WAV bytes, played below assistant messages only.
        is_loading: When True, show the animated typing indicator instead of text.
    """
    role = "user" if is_user else "assistant"
    avatar_icon = "👤" if is_user else "🤖"

    # Both branches share the same bubble skeleton; only the content div varies.
    content_classes = f"message-content {role}-content"
    if is_loading:
        content_classes += " loading-message"
        inner = (
            '<div class="typing-indicator">'
            '<span></span>'
            '<span></span>'
            '<span></span>'
            '</div>'
        )
    else:
        inner = message

    _, center, _ = st.columns([0.1, 0.8, 0.1])
    with center:
        st.markdown(
            f'<div class="{role}-message">'
            f'<div class="message-avatar {role}-avatar">{avatar_icon}</div>'
            f'<div class="{content_classes}">{inner}</div>'
            f'</div>',
            unsafe_allow_html=True,
        )
        # Audio player is shown under assistant messages only.
        if audio_bytes and not is_user:
            st.markdown('<div class="audio-player">', unsafe_allow_html=True)
            st.audio(audio_bytes, format='audio/wav')
            st.markdown('</div>', unsafe_allow_html=True)
def set_page_style():
    """Inject the app's CSS: chat-bubble layout, typing animation, and
    theme-aware colors (CSS variables plus a dark-mode media query).

    The styles target the HTML emitted by display_chat_message() and the
    Streamlit input widgets; call once at the top of main().
    """
    # The entire stylesheet is one runtime string; its content is left untouched.
    st.markdown("""
    <style>
        /* Theme-responsive styles */
        .chat-container {
            padding: 10px 0;
        }
        .user-message {
            display: flex;
            align-items: flex-start;
            margin-bottom: 24px;
        }
        .assistant-message {
            display: flex;
            align-items: flex-start;
            margin-bottom: 24px;
            flex-direction: row-reverse;
        }
        .message-avatar {
            width: 40px;
            height: 40px;
            border-radius: 50%;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 18px;
            color: white;
        }
        .user-avatar {
            background-color: var(--primary-color, #e91e63);
            margin-right: 12px;
        }
        .assistant-avatar {
            background-color: #795548;
            margin-left: 12px;
        }
        .message-content {
            background-color: var(--secondary-background-color, rgba(128, 128, 128, 0.15));
            padding: 12px 16px;
            border-radius: 18px;
            max-width: 75%;
            color: var(--text-color, inherit);
        }
        .user-content {
            border-top-left-radius: 4px;
        }
        .assistant-content {
            border-top-right-radius: 4px;
        }
        .audio-player {
            margin-top: 8px;
            width: 100%;
            border-radius: 12px;
            overflow: hidden;
        }
        .stAudio {
            width: 100% !important;
        }
        .stAudio > div {
            border-radius: 12px !important;
        }
        .title-container {
            text-align: center;
            padding: 15px;
            background-color: var(--primary-color, #1976d2);
            border-radius: 10px;
            color: white;
            margin-bottom: 20px;
        }
        /* Improved input container with proper alignment and theme compatibility */
        .input-area {
            display: flex;
            align-items: center;
            margin-top: 20px;
            gap: 10px;
            background-color: var(--input-bg-color, rgba(128, 128, 128, 0.1));
            border-radius: 24px;
            padding: 8px 16px;
            width: 100%;
            border: 1px solid var(--input-border-color, rgba(128, 128, 128, 0.2));
        }
        .input-area .stTextInput {
            flex-grow: 1;
        }
        .stTextInput>div>div>input {
            background-color: transparent !important;
            border: none !important;
            padding: 8px 0 !important;
            box-shadow: none !important;
        }
        /* Remove padding and margin from the container columns */
        .input-container-col .stTextInput {
            margin-bottom: 0 !important;
        }
        .button-col div {
            display: flex;
            justify-content: flex-end;
        }
        .send-button {
            background-color: var(--primary-color, #1976d2);
            color: white;
            border-radius: 50%;
            width: 40px !important;
            height: 40px !important;
            display: flex !important;
            align-items: center !important;
            justify-content: center !important;
            padding: 0 !important;
            min-height: 0 !important;
        }
        /* Loading indicator animation */
        .loading-message {
            min-width: 70px;
        }
        .typing-indicator {
            display: flex;
            align-items: center;
            justify-content: center;
        }
        .typing-indicator span {
            height: 8px;
            width: 8px;
            margin: 0 2px;
            background-color: var(--text-color, #9E9E9E);
            display: block;
            border-radius: 50%;
            opacity: 0.4;
        }
        .typing-indicator span:nth-of-type(1) {
            animation: typing 1s infinite;
        }
        .typing-indicator span:nth-of-type(2) {
            animation: typing 1s 0.2s infinite;
        }
        .typing-indicator span:nth-of-type(3) {
            animation: typing 1s 0.4s infinite;
        }
        @keyframes typing {
            0% {
                transform: translateY(0px);
                opacity: 0.4;
            }
            50% {
                transform: translateY(-5px);
                opacity: 0.8;
            }
            100% {
                transform: translateY(0px);
                opacity: 0.4;
            }
        }
        /* Align the columns properly */
        .stHorizontal .stColumn {
            padding-left: 0 !important;
            padding-right: 0 !important;
        }
        /* Add CSS variables for theme detection */
        :root {
            --primary-color: #1976d2;
            --secondary-background-color: rgba(128, 128, 128, 0.15);
            --text-color: inherit;
            --input-bg-color: rgba(128, 128, 128, 0.1);
            --input-border-color: rgba(128, 128, 128, 0.2);
        }
        /* Dark mode specific adjustments */
        @media (prefers-color-scheme: dark) {
            :root {
                --secondary-background-color: rgba(70, 70, 70, 0.3);
                --input-bg-color: rgba(70, 70, 70, 0.2);
                --input-border-color: rgba(100, 100, 100, 0.3);
            }
            .message-content {
                color: rgba(255, 255, 255, 0.9);
            }
            .stTextInput>div>div>input {
                color: rgba(255, 255, 255, 0.9) !important;
            }
        }
    </style>
    """, unsafe_allow_html=True)
def main():
    """Streamlit entry point: sidebar settings, chat display, and the
    text/voice input flow.

    Control flow relies on Streamlit's rerun model: a user action appends to
    ``st.session_state.chat_history`` and sets ``processing=True`` before
    ``st.rerun()``; the next run detects the processing flag, generates the
    assistant response, clears the flag, and reruns again to render it.
    """
    set_page_style()

    # Sidebar configuration
    with st.sidebar:
        st.markdown("<h2 style='text-align: center;'>Control Panel</h2>", unsafe_allow_html=True)

        st.markdown("<p>Language Settings</p>", unsafe_allow_html=True)
        target_language = st.selectbox(
            "Select language:",
            ["English", "Swahili", "Kikuyu", "Luo", "Kamba", "Kalenjin", "Luhya", "French", "Kinyarwanda", "Spanish", "Arabic"],
            key="language_selector"
        )

        st.markdown("<p>Proficiency Level</p>", unsafe_allow_html=True)
        proficiency_level = st.radio(
            "Select your technical understanding:",
            ["Beginner", "Intermediate", "Advanced"],
            key="proficiency_level"
        )

        st.markdown("<p>Input Method</p>", unsafe_allow_html=True)
        input_method = st.radio(
            "Choose input method:",
            ["Text", "Voice"],
            key="input_method"
        )

        if st.button("Clear Conversation", key="clear_button"):
            st.session_state.chat_history = []
            st.rerun()

    # Main content area
    st.markdown("""
    <div class="title-container">
        <h1>D&S Product Assistant</h1>
        <p>Ask about D&S products in any language through text or voice! First select your language and then set your proficiency level through the side bar options.</p>
    </div>
    """, unsafe_allow_html=True)

    # Initialize session state.
    # chat_history entries are (is_user, text) or (is_user, text, audio_bytes).
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    # Rebuild the QA chain whenever the proficiency level changes, since the
    # proficiency instruction is baked into the chain's prompt.
    if 'qa_chain' not in st.session_state or st.session_state.get('current_proficiency') != proficiency_level:
        st.session_state.qa_chain = initialize_chatbot(proficiency_level)
        st.session_state.current_proficiency = proficiency_level
    if 'processing' not in st.session_state:
        st.session_state.processing = False
    # input_key is bumped to force a fresh (empty) text_input widget after send.
    if 'input_key' not in st.session_state:
        st.session_state.input_key = 0

    # Chat display container
    chat_container = st.container()
    with chat_container:
        st.markdown('<div class="chat-container">', unsafe_allow_html=True)
        for msg in st.session_state.chat_history:
            # Older tuples may lack the audio element; pass None in that case.
            display_chat_message(msg[0], msg[1], msg[2] if len(msg) > 2 else None)
        # Show loading indicator if processing
        if st.session_state.processing:
            display_chat_message(False, "", None, True)
        st.markdown('</div>', unsafe_allow_html=True)

    # Input area - improved alignment
    st.markdown('<div class="input-area">', unsafe_allow_html=True)

    # Create columns for input and button with CSS classes
    input_col, button_col = st.columns([0.9, 0.1])

    # Add CSS class to columns
    input_col.markdown('<div class="input-container-col">', unsafe_allow_html=True)
    button_col.markdown('<div class="button-col">', unsafe_allow_html=True)

    with input_col:
        if input_method == "Text":
            user_input = st.text_input(
                "",
                key=f"text_input_{st.session_state.input_key}",  # Use a dynamic key
                placeholder="Ask about D&S products...",
                label_visibility="collapsed"
            )
        else:
            st.markdown("""
            <p style="margin: 0; color: var(--text-color, inherit);">📢 Record your question:</p>
            """, unsafe_allow_html=True)
            audio_bytes = audio_recorder(
                pause_threshold=2.0,
                sample_rate=16000,
                text="🎤",
                neutral_color="var(--primary-color, #1976d2)",
                recording_color="#e91e63"
            )

    with button_col:
        if input_method == "Text":
            send_clicked = st.button("↑", key="send_button", help="Send message", type="primary", use_container_width=False)

    # Close column div tags
    input_col.markdown('</div>', unsafe_allow_html=True)
    button_col.markdown('</div>', unsafe_allow_html=True)

    # Close input area div
    st.markdown('</div>', unsafe_allow_html=True)

    # Process text input if send button is clicked.
    # NOTE: user_input/send_clicked only exist in Text mode; the input_method
    # check short-circuits first, so no NameError in Voice mode.
    if input_method == "Text" and user_input and send_clicked:
        # Show user message
        with chat_container:
            display_chat_message(True, user_input)
        # Save to history
        st.session_state.chat_history.append((True, user_input))
        # Set processing state
        st.session_state.processing = True
        # Increment the input key to clear the input
        st.session_state.input_key += 1
        st.rerun()

    # Process text input if in processing state (second pass after rerun)
    if st.session_state.processing and input_method == "Text":
        # Process the last user input
        last_user_input = st.session_state.chat_history[-1][1]
        # Get assistant response
        answer, answer_audio = process_text_input(
            last_user_input, target_language, proficiency_level, st.session_state.qa_chain
        )
        # Add to history and clear processing state
        st.session_state.chat_history.append((False, answer, answer_audio))
        st.session_state.processing = False
        st.rerun()

    # Process voice input if provided
    if input_method == "Voice" and audio_bytes:
        # Display user message first
        with st.spinner("Processing your voice input..."):
            try:
                original_text, translated_text, translated_audio = process_audio_with_openai(
                    audio_bytes, target_language, proficiency_level
                )
                # Add user message to history
                st.session_state.chat_history.append((True, original_text))
                # Set processing state
                st.session_state.processing = True
                st.rerun()
            except Exception as e:
                st.error(f"Error processing audio: {str(e)}")

    # Process voice response if in processing state (second pass after rerun)
    if st.session_state.processing and input_method == "Voice":
        # Check if we have a user message to respond to
        if st.session_state.chat_history and st.session_state.chat_history[-1][0]:
            last_user_input = st.session_state.chat_history[-1][1]
            # Get response
            answer, answer_audio = process_text_input(
                last_user_input, target_language, proficiency_level, st.session_state.qa_chain
            )
            # Add to history and clear processing state
            st.session_state.chat_history.append((False, answer, answer_audio))
            st.session_state.processing = False
            st.rerun()
def process_and_display_response(user_input, target_language, proficiency_level, chat_container):
    """Render a question/answer exchange and append both turns to history.

    Displays the user's message, generates the assistant's text+audio reply
    via the session QA chain, displays it, and records both turns in
    ``st.session_state.chat_history``.
    """
    with st.spinner("Processing your question..."):
        # Show the user's side of the exchange first.
        with chat_container:
            display_chat_message(True, user_input)

        # Generate the assistant reply (text + synthesized audio).
        reply_text, reply_audio = process_text_input(
            user_input, target_language, proficiency_level, st.session_state.qa_chain
        )

        # Show the assistant's side.
        with chat_container:
            display_chat_message(False, reply_text, reply_audio)

        # Persist both turns for subsequent reruns.
        st.session_state.chat_history.append((True, user_input))
        st.session_state.chat_history.append((False, reply_text, reply_audio))
# Script entry point (Streamlit executes the module top-to-bottom).
if __name__ == "__main__":
    main()