# birdsGPT / app.py
# Hugging Face Space by Gopikanth123
# Commit d8b0e8d ("Update app.py", verified)
from dotenv import load_dotenv
import gradio as gr
import os
from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import random
def select_random_name():
    """Pick one of the two chat role labels uniformly at random.

    Returns:
        Either ``'bot'`` or ``'user'``, drawn with ``random.choice`` from the
        module-level ``random`` generator (so ``random.seed`` makes it repeatable).
    """
    role_labels = ("bot", "user")
    picked = random.choice(role_labels)
    return picked
# Load variables such as HF_TOKEN from a .env file (if present) into os.environ.
load_dotenv()

# Global LlamaIndex configuration shared by indexing and querying.
# One Hugging Face repo supplies both the hosted chat model and its tokenizer.
_LLM_MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
Settings.llm = HuggingFaceInferenceAPI(
    model_name=_LLM_MODEL_ID,
    tokenizer_name=_LLM_MODEL_ID,
    token=os.getenv("HF_TOKEN"),
    context_window=3000,
    max_new_tokens=512,
    # Low temperature for consistent, factual-sounding answers.
    generate_kwargs={"temperature": 0.1},
)
# Embedding model used to vectorize the ingested documents and the queries.
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Where the persisted vector index lives, and where the source PDFs are read from.
PERSIST_DIR = "db"
PDF_DIRECTORY = 'data'

# Create both folders up front so reading PDFs / persisting the index never
# hits a missing path (same order as before: data first, then db).
for _required_dir in (PDF_DIRECTORY, PERSIST_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# In-process log of (question, answer) pairs, appended to by handle_query.
# NOTE(review): this list is module-global, so all users/sessions share it.
current_chat_history = []

# NOTE(review): `kkk` is not referenced anywhere else in this file.
kkk = select_random_name()
def data_ingestion_from_directory():
    """Rebuild the vector index from every file in PDF_DIRECTORY and persist it.

    Loads all documents with ``SimpleDirectoryReader``, builds a fresh
    ``VectorStoreIndex`` from them (presumably embedded with the
    ``Settings.embed_model`` configured at module level), and writes the index
    to PERSIST_DIR, where ``handle_query`` reloads it with
    ``load_index_from_storage``.

    NOTE(review): ``SimpleDirectoryReader`` is expected to raise if the
    directory contains no files, so an empty ``data/`` folder would make
    startup fail — confirm and decide whether that is desired.

    Returns:
        The newly built index. Previously this returned ``None``; callers that
        ignore the return value are unaffected.
    """
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    # The index creates and owns its storage context; persisting that context
    # writes the index files under PERSIST_DIR.
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index
# Prompt for retrieval-augmented answers. {chat_history} is filled by
# handle_query via partial_format; {context_str} (retrieved document chunks)
# and {query_str} are filled by llama_index's response synthesizer.
_QA_PROMPT_TEXT = """
You are the Disease chatbot, known as Disease Helper. Your goal is to provide accurate and professional answers to user queries based on the information available about the Diseases. Always respond clearly and concisely, ideally within 10-15 words. If you don't know the answer, say so politely.
Previous conversation (most recent first):
{chat_history}
Context:
{context_str}
Question:
{query_str}
"""


def _format_chat_history(turns, max_turns):
    """Render the most recent (question, answer) pairs as prompt text.

    Args:
        turns: List of ``(question, answer)`` tuples in chronological order.
        max_turns: How many of the latest turns to include; ``<= 0`` disables
            history entirely.

    Returns:
        The selected turns, newest first, as ``User asked: ... / Bot answered:
        ...`` lines (turns with a blank question are skipped), or ``"(none)"``
        when nothing is included.
    """
    recent = turns[-max_turns:] if max_turns > 0 else []
    rendered = [
        f"User asked: '{question}'\nBot answered: '{answer}'\n"
        for question, answer in reversed(recent)
        if question.strip()
    ]
    return "".join(rendered) or "(none)"


def _response_text(result):
    """Return the answer string from a query-engine result, or None if unusable.

    Accepts either an object with a ``response`` attribute or a dict with a
    ``"response"`` key. Missing, non-string, blank, and "Empty Response"
    answers all count as "no answer".
    """
    if isinstance(result, dict):
        text = result.get("response")
    else:
        text = getattr(result, "response", None)
    if not isinstance(text, str) or not text.strip():
        return None
    # NOTE: llama_index's response synthesizers return the literal
    # "Empty Response" when retrieval yields no nodes.
    if text.strip() == "Empty Response":
        return None
    return text


def _general_llm_answer(query):
    """Answer *query* directly with the configured LLM, without retrieval.

    Returns the LLM's completion text, or an apology if it produced nothing.
    """
    fallback = "Sorry, I don’t have that information available in the provided data. Let me provide a general response."
    # LlamaIndex LLMs expose complete()/chat(), not query(); the completion
    # result carries its text on `.text`.
    completion = Settings.llm.complete(query)
    text = getattr(completion, "text", None)
    return text if isinstance(text, str) and text.strip() else fallback


def handle_query(query, max_history_turns=5):
    """Answer *query* from the persisted document index, using recent chat turns.

    The last ``max_history_turns`` exchanges from ``current_chat_history`` are
    injected into the prompt's ``{chat_history}`` slot. The previous version
    passed the history as an unused ``context_str=`` keyword to
    ``as_query_engine``, so it never reached the model: the template's
    ``{context_str}`` is always filled with retrieved chunks. If retrieval does
    not produce a usable answer, the LLM is asked directly instead.

    Args:
        query: The user's question.
        max_history_turns: Maximum number of previous turns to include in the
            prompt. The cap keeps the prompt from growing without bound.

    Returns:
        The answer text.

    Side effects:
        Appends ``(query, answer)`` to the module-level ``current_chat_history``.
        NOTE(review): that list is shared by every user of the app.
    """
    history_text = _format_chat_history(current_chat_history, max_history_turns)
    qa_template = ChatPromptTemplate.from_messages(
        [("user", _QA_PROMPT_TEXT)]
    ).partial_format(chat_history=history_text)

    # Reload the index persisted by data_ingestion_from_directory().
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
    query_engine = index.as_query_engine(text_qa_template=qa_template)

    answer = _response_text(query_engine.query(query))
    if answer is None:
        answer = _general_llm_answer(query)

    current_chat_history.append((query, answer))
    return answer
def predict(message, history):
    """Answer *message* and wrap the reply in an HTML card with the bot logo.

    Args:
        message: The user's question, forwarded to ``handle_query``.
        history: Gradio-supplied chat history; unused because ``handle_query``
            keeps its own conversation log.

    Returns:
        HTML markup: the circular logo followed by the HTML-escaped answer,
        laid out by the ``.circle-logo`` / ``.response-with-logo`` /
        ``.response-text`` CSS rules.
    """
    from html import escape

    logo_html = '''
<div class="circle-logo">
<img src="https://rb.gy/8r06eg" alt="FernAi" style="width: 100%; height: auto;">
</div>
'''
    response = handle_query(message)
    # Escape the answer before embedding it in markup. Model output may contain
    # "<", "&" or echoed user text that would otherwise be parsed as HTML.
    safe_response = escape("" if response is None else str(response), quote=False)
    response_with_logo = f'''
<div class="response-with-logo">
{logo_html}
<div class="response-text">
{safe_response}
</div>
</div>
'''
    return response_with_logo
# Gradio ChatInterface callback.
def chat_interface(message, history):
    """Answer *message* for the Gradio chat UI via ``handle_query``.

    Args:
        message: The user's latest message.
        history: Conversation history supplied by Gradio; unused because
            ``handle_query`` tracks the conversation itself.

    Returns:
        The answer text. If answering raises, returns the error description
        instead, so the UI shows a reply rather than failing the turn.
    """
    try:
        return handle_query(message)
    except Exception as exc:  # top-level UI boundary: report, don't crash
        import logging

        # Keep the traceback server-side. The message text is deliberately not
        # logged: it may contain users' health questions.
        logging.getLogger(__name__).exception("handle_query failed")
        # Some exceptions stringify to "" (e.g. a bare KeyError()); fall back
        # to the class name so the user never receives a blank reply.
        return str(exc) or type(exc).__name__
# Custom CSS injected into the page via `css=` on gr.ChatInterface below.
# - .circle-logo / .response-with-logo / .response-text style the HTML that
#   predict() builds. predict() is not the callback wired into the interface
#   (chat_interface is), so these rules only matter if that changes.
# - NOTE(review): no markup visible in this file uses the .chatbox class.
# - NOTE(review): the `svelte-xxxx` selectors look like build-generated Svelte
#   class hashes from one specific Gradio version. They will likely stop
#   matching (silently) after a Gradio upgrade; verify.
css = '''
body {
background: linear-gradient(135deg, #e0f7fa, #ffebee); /* Gradient background */
font-family: 'Arial', sans-serif;
color: #333;
margin: 0;
padding: 20px;
}
.chatbox {
display: flex;
flex-direction: column;
max-width: 600px;
margin: auto;
border-radius: 15px;
background-color: #ffffff;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
padding: 20px;
}
.circle-logo {
display: inline-block;
width: 60px;
height: 60px;
border-radius: 50%;
overflow: hidden;
margin-right: 10px;
vertical-align: middle;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.2);
}
.response-with-logo {
display: flex;
align-items: center;
margin-bottom: 20px;
padding: 10px;
background-color: #f9f9f9;
border-radius: 10px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
transition: background-color 0.3s;
}
.response-with-logo:hover {
background-color: #e0f7fa; /* Light blue on hover */
}
.response-text {
font-size: 18px;
font-weight: bold;
color: #333;
}
footer {
display: none !important;
}
label.svelte-1b6s6s {display: none;}
div.svelte-rk35yg {display: none;}
div.progress-text.svelte-z7cif2.meta-text {display: none;}
@media (max-width: 600px) {
.chatbox {
width: 90%; /* Responsive width for smaller screens */
}
}
'''
# Build (and persist) the vector index from the PDFs before the UI starts
# accepting questions, since handle_query reloads it from PERSIST_DIR.
print("Processing PDF ingestion from directory:", PDF_DIRECTORY)
data_ingestion_from_directory()

# Create the chat UI and serve it. The clear/undo/retry buttons are disabled by
# passing None. NOTE(review): these keyword arguments exist only in older
# Gradio ChatInterface releases; confirm the pinned Gradio version.
demo = gr.ChatInterface(
    chat_interface,
    css=css,
    description="Disease Helper Chatbot",
    clear_btn=None,
    undo_btn=None,
    retry_btn=None,
)
demo.launch()