Spaces: Runtime error
| import gradio as gr | |
| import logging | |
| from backend.train import ModelTrainer | |
| from backend.rag import PineconeRetriever | |
| from dotenv import load_dotenv | |
| import os | |
# Load variables from a local .env file into os.environ; the PINECONE_* settings
# read in the __main__ block come from there or from the real environment.
load_dotenv()

# Configure logging: root logger at INFO, timestamped single-line records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class ChatbotUI:
    """Gradio-based chatbot UI for an HR legal assistant, with RAG enabled by default.

    Each user message is optionally augmented with the top Pinecone match (when
    ``use_rag`` is true), wrapped in a fixed instruction prompt, and answered by
    the model exposed through ``ModelTrainer``.
    """

    # Instruction text placed at the top of every prompt.
    PRESET = ("You are a HR legal assistant. You do not respond as 'User' or pretend to be 'User'. "
              "You only respond once as 'Assistant'. Avoid Yes or No.")
    # Marker that starts a (hallucinated) next user turn in the model output.
    _USER_TURN_MARKER = "### User:"

    def __init__(self, model_name, pinecone_api, pinecone_index, pinecone_namespace, use_rag=True):
        """Load the model and connect the retriever.

        Args:
            model_name: Model identifier passed to ``ModelTrainer``.
            pinecone_api: Pinecone API key.
            pinecone_index: Name of the Pinecone index to query.
            pinecone_namespace: Pinecone namespace to query (may be None when the
                env var is unset; TODO confirm how PineconeRetriever treats None).
            use_rag: When true (the default), retrieved context is added to every prompt.
        """
        logging.info("Initializing ChatbotUI...")
        self.trainer = ModelTrainer(model_name)
        self.retriever = PineconeRetriever(pinecone_api, pinecone_index, pinecone_namespace)
        self.use_rag = use_rag
        logging.info("ChatbotUI initialized successfully with RAG %s.",
                     "enabled" if use_rag else "disabled")

    def _build_prompt(self, input_text, context):
        """Return the full model prompt; the Context section is omitted when *context* is None."""
        sections = [self.PRESET]
        if context is not None:
            sections.append(f"### Context: {context}")
        sections.append(f"### User: {input_text.strip()}")
        sections.append("### Response:")
        return "\n\n".join(sections)

    def chatbot_response(self, input_text, max_new_tokens=2000):
        """Generate a reply to *input_text*, grounded in retrieved context when RAG is on.

        Args:
            input_text: The user's message.
            max_new_tokens: Upper bound on generated tokens (default 2000). Lower it
                to trade maximum answer length for speed.

        Returns:
            The model's reply, without the prompt and with any invented follow-up
            "### User:" turn cut off.
        """
        context = None
        if self.use_rag:
            retrieved_docs = self.retriever.retrieve_context(input_text, top_k=1)
            # A missing result becomes empty context instead of crashing on .strip().
            context = (retrieved_docs or "").strip()
        prompt = self._build_prompt(input_text, context)

        tokenizer = self.trainer.tokenizer
        model = self.trainer.model
        # Put the encoded prompt on the model's device: CPU input tensors fed to a
        # GPU-hosted model make generate() fail with a device-mismatch error.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.95,
            temperature=0.1,
            repetition_penalty=1.2,
        )
        # Decode only the newly generated tokens. Removing the prompt from the decoded
        # text via startswith() is fragile: the tokenize/decode round trip need not
        # reproduce the prompt exactly, and then the whole prompt leaked into the reply.
        # Assumes a decoder-only (causal LM) model whose generate() output begins with
        # the prompt tokens, as the original prefix-stripping implied — TODO confirm.
        prompt_length = len(inputs["input_ids"][0])
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # The model may keep going and invent the next user turn; keep only its own answer.
        response = response.split(self._USER_TURN_MARKER, 1)[0]
        return response.strip()

    def _handle_user_message(self, message, history):
        """Gradio callback: append (message, reply) to the history and clear the textbox.

        Returns:
            (chatbot history, new textbox value, state history). Blank messages leave
            the history unchanged so the model is not run on empty input.
        """
        if not message or not message.strip():
            return history, "", history
        response = self.chatbot_response(message)
        history = history + [(message, response)]
        return history, "", history

    def clear_conversation(self):
        """Clears the entire conversation history and resets the input box."""
        logging.info("Clearing conversation history.")
        return [], "", []

    def launch(self):
        """Build the Gradio Blocks layout, wire up the events, and start the server."""
        logging.info("Launching chatbot UI...")
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.Markdown(
                "<h1 style='text-align: center; color: #2E3A59;'>HAP Chatbot</h1>\n"
                "<p style='text-align: center; color: #4A5568;'>An intelligent HR legal assistant powered by AI.</p>"
            )
            with gr.Row():
                chatbot = gr.Chatbot(label="Chat", height=400)
            with gr.Row():
                msg = gr.Textbox(label="Your Message", placeholder="Enter your message here...", lines=2)
                send_btn = gr.Button("Send", variant="primary")
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")
            state = gr.State([])

            send_btn.click(self._handle_user_message, inputs=[msg, state], outputs=[chatbot, msg, state])
            # Submitting the textbox (Enter) sends the message too, not just the button.
            msg.submit(self._handle_user_message, inputs=[msg, state], outputs=[chatbot, msg, state])
            clear_btn.click(self.clear_conversation, inputs=[], outputs=[chatbot, msg, state], queue=False)
        demo.launch()
def main():
    """Build the chatbot from environment settings and launch the UI.

    Reads PINECONE_API_KEY and PINECONE_INDEX (required) and PINECONE_NAMESPACE
    (optional; passed through as None when unset).

    Raises:
        SystemExit: if a required Pinecone setting is missing or empty. The missing
            name is reported at startup instead of being passed to
            PineconeRetriever as None.
    """
    missing = [name for name in ("PINECONE_API_KEY", "PINECONE_INDEX") if not os.getenv(name)]
    if missing:
        raise SystemExit(f"Missing required environment variable(s): {', '.join(missing)}")
    chatbot = ChatbotUI(
        model_name="sainoforce/modelv5",
        pinecone_api=os.getenv("PINECONE_API_KEY"),
        pinecone_index=os.getenv("PINECONE_INDEX"),
        pinecone_namespace=os.getenv("PINECONE_NAMESPACE"),
    )
    chatbot.launch()


if __name__ == "__main__":
    main()