"""Gradio chatbot UI for an HR legal assistant, with RAG enabled by default."""

import logging
import os

import gradio as gr
from dotenv import load_dotenv

from backend.rag import PineconeRetriever
from backend.train import ModelTrainer

load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class ChatbotUI:
    """Class for the Gradio-based chatbot UI with RAG enabled by default."""

    def __init__(self, model_name, pinecone_api, pinecone_index, pinecone_namespace):
        """Load the model and connect to the Pinecone index.

        Args:
            model_name: Model identifier handed to ModelTrainer.
            pinecone_api: Pinecone API key.
            pinecone_index: Name of the Pinecone index to query.
            pinecone_namespace: Namespace within that index.
        """
        logging.info("Initializing ChatbotUI...")
        self.trainer = ModelTrainer(model_name)
        self.retriever = PineconeRetriever(pinecone_api, pinecone_index, pinecone_namespace)
        self.use_rag = True  # RAG enabled by default
        logging.info("ChatbotUI initialized successfully with RAG enabled.")

    def chatbot_response(self, input_text):
        """Generate response using retrieved context and the trained model.

        Args:
            input_text: Raw user message from the UI.

        Returns:
            The model's generated reply as a plain string.
        """
        # Guard against blank submissions instead of sending an empty prompt
        # to the model.
        if not input_text or not input_text.strip():
            return "Please enter a question."

        # Retrieve relevant context from Pinecone.
        # NOTE(review): assumes retrieve_context returns a string (it is
        # .strip()'ed below) — confirm against backend.rag.
        retrieved_docs = self.retriever.retrieve_context(input_text, top_k=1)
        preset = ("You are a HR legal assistant. You do not respond as 'User' or pretend to be 'User'. "
                  "You only respond once as 'Assistant'. Avoid Yes or No.")
        latest_prompt = (f"{preset}\n\n### Context: {retrieved_docs.strip()}\n\n"
                         f"### User: {input_text.strip()}\n\n### Response:")

        inputs = self.trainer.tokenizer(latest_prompt, return_tensors="pt")
        outputs = self.trainer.model.generate(
            **inputs,
            max_new_tokens=2000,  # upper bound on *generated* tokens (prompt excluded)
            do_sample=True,
            top_p=0.95,
            temperature=0.1,
            repetition_penalty=1.2,
        )
        # Decode only the newly generated tokens. The previous approach
        # (response.startswith(latest_prompt) + string slicing) silently left
        # the whole prompt in the reply whenever decoding did not reproduce
        # the prompt byte-for-byte, which tokenizers frequently do not.
        prompt_len = inputs["input_ids"].shape[-1]
        response = self.trainer.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()
        return response

    def clear_conversation(self):
        """Clears the entire conversation history and resets the input box.

        Returns:
            Empty chatbot history, empty textbox value, and empty state list —
            matching the (chatbot, msg, state) outputs wired in launch().
        """
        logging.info("Clearing conversation history.")
        return [], "", []

    def launch(self):
        """Build the Gradio Blocks interface and start serving it."""
        logging.info("Launching chatbot UI...")
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.Markdown("""
An intelligent HR legal assistant powered by AI.
""")
            with gr.Row():
                chatbot = gr.Chatbot(label="Chat", height=400)
            with gr.Row():
                msg = gr.Textbox(label="Your Message", placeholder="Enter your message here...", lines=2)
                send_btn = gr.Button("Send", variant="primary")
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")
            state = gr.State([])

            def user_message(message, history):
                """Handle one turn: generate a reply and append it to history."""
                response = self.chatbot_response(message)
                history = history + [(message, response)]
                # Return updated conversation history, clear the textbox, and update the state
                return history, "", history

            send_btn.click(user_message, inputs=[msg, state], outputs=[chatbot, msg, state])
            # Also send on Enter — same handler and outputs as the Send button.
            msg.submit(user_message, inputs=[msg, state], outputs=[chatbot, msg, state])
            clear_btn.click(self.clear_conversation, inputs=[], outputs=[chatbot, msg, state], queue=False)
        demo.launch()


if __name__ == "__main__":
    # Fail fast with a clear message instead of passing None credentials
    # through to the Pinecone client.
    _required = ("PINECONE_API_KEY", "PINECONE_INDEX", "PINECONE_NAMESPACE")
    _missing = [name for name in _required if not os.getenv(name)]
    if _missing:
        raise SystemExit(f"Missing required environment variables: {', '.join(_missing)}")

    chatbot = ChatbotUI(
        model_name="sainoforce/modelv5",
        pinecone_api=os.getenv("PINECONE_API_KEY"),
        pinecone_index=os.getenv("PINECONE_INDEX"),
        pinecone_namespace=os.getenv("PINECONE_NAMESPACE")
    )
    chatbot.launch()