import gradio as gr
import logging
from backend.train import ModelTrainer
from backend.rag import PineconeRetriever
from dotenv import load_dotenv
import os
# Load environment variables (Pinecone credentials) from a local .env file.
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class ChatbotUI:
    """Gradio-based chatbot UI for an HR legal assistant with RAG enabled by default."""

    def __init__(self, model_name, pinecone_api, pinecone_index, pinecone_namespace):
        """Set up the model trainer and Pinecone retriever.

        Args:
            model_name: Hugging Face model identifier passed to ModelTrainer.
            pinecone_api: Pinecone API key.
            pinecone_index: Name of the Pinecone index to query.
            pinecone_namespace: Namespace within the index.
        """
        logging.info("Initializing ChatbotUI...")
        self.trainer = ModelTrainer(model_name)
        self.retriever = PineconeRetriever(pinecone_api, pinecone_index, pinecone_namespace)
        self.use_rag = True  # RAG enabled by default
        logging.info("ChatbotUI initialized successfully with RAG enabled.")

    def chatbot_response(self, input_text):
        """Generate a response using retrieved context and the trained model.

        Args:
            input_text: The user's message.

        Returns:
            The decoded model reply with the prompt prefix stripped.
        """
        # Retrieve relevant context from Pinecone.
        # NOTE(review): assumes retrieve_context returns a string (it is
        # .strip()'d below) — confirm against PineconeRetriever.
        retrieved_docs = self.retriever.retrieve_context(input_text, top_k=1)
        preset = ("You are a HR legal assistant. You do not respond as 'User' or pretend to be 'User'. "
                  "You only respond once as 'Assistant'. Avoid Yes or No.")
        latest_prompt = (f"{preset}\n\n### Context: {retrieved_docs.strip()}\n\n"
                         f"### User: {input_text.strip()}\n\n### Response:")
        inputs = self.trainer.tokenizer(latest_prompt, return_tensors="pt")
        outputs = self.trainer.model.generate(
            **inputs,
            max_new_tokens=2000,  # upper bound on generated tokens per reply
            do_sample=True,
            top_p=0.95,
            temperature=0.1,  # low temperature keeps answers focused
            repetition_penalty=1.2,
        )
        response = self.trainer.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        # Causal LMs echo the input prompt; drop it so only the reply remains.
        if response.startswith(latest_prompt):
            response = response[len(latest_prompt):].strip()
        return response

    def clear_conversation(self):
        """Clears the entire conversation history and resets the input box.

        Returns:
            A 3-tuple mapped to (chatbot, msg, state): empty history,
            empty textbox, empty state.
        """
        logging.info("Clearing conversation history.")
        return [], "", []

    def launch(self):
        """Build the Gradio Blocks layout, wire up events, and launch the app."""
        logging.info("Launching chatbot UI...")
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.Markdown("""
            <h1 style='text-align: center; color: #2E3A59;'>HAP Chatbot</h1>
            <p style='text-align: center; color: #4A5568;'>An intelligent HR legal assistant powered by AI.</p>
            """)
            with gr.Row():
                chatbot = gr.Chatbot(label="Chat", height=400)
            with gr.Row():
                msg = gr.Textbox(label="Your Message", placeholder="Enter your message here...", lines=2)
                send_btn = gr.Button("Send", variant="primary")
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")
            state = gr.State([])

            def user_message(message, history):
                # Skip blank/whitespace-only submissions so the model is not
                # invoked needlessly; just clear the textbox.
                if not message or not message.strip():
                    return history, "", history
                response = self.chatbot_response(message)
                history = history + [(message, response)]
                # Return updated conversation history, clear the textbox, and update the state
                return history, "", history

            send_btn.click(user_message, inputs=[msg, state], outputs=[chatbot, msg, state])
            clear_btn.click(self.clear_conversation, inputs=[], outputs=[chatbot, msg, state], queue=False)
        demo.launch()
if __name__ == "__main__":
    # Pinecone credentials come from the environment (populated by load_dotenv).
    ui = ChatbotUI(
        model_name="sainoforce/modelv5",
        pinecone_api=os.getenv("PINECONE_API_KEY"),
        pinecone_index=os.getenv("PINECONE_INDEX"),
        pinecone_namespace=os.getenv("PINECONE_NAMESPACE"),
    )
    ui.launch()