File size: 3,835 Bytes
db7016c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10f44c3
db7016c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10f44c3
db7016c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e8bfb7
db7016c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import logging
from backend.train import ModelTrainer
from backend.rag import PineconeRetriever
from dotenv import load_dotenv
import os

# Load variables from a .env file (python-dotenv) into the process
# environment so the os.getenv() calls in the entry point can see them.
load_dotenv()

# Root logger: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

class ChatbotUI:
    """Class for the Gradio-based chatbot UI with RAG enabled by default.

    Wraps a ``ModelTrainer`` (used for its ``tokenizer`` and ``model``
    attributes) and a ``PineconeRetriever`` that supplies context for each
    question. Set ``use_rag = False`` on an instance to skip retrieval.
    """

    # Instruction prefix prepended to every prompt sent to the model.
    _PRESET = ("You are a HR legal assistant. You do not respond as 'User' or pretend to be 'User'. "
               "You only respond once as 'Assistant'. Avoid Yes or No.")

    def __init__(self, model_name, pinecone_api, pinecone_index, pinecone_namespace):
        """Build the model wrapper and the Pinecone retriever.

        Args:
            model_name: Model identifier passed to ``ModelTrainer``.
            pinecone_api: API key passed to ``PineconeRetriever``.
            pinecone_index: Index name passed to ``PineconeRetriever``.
            pinecone_namespace: Namespace passed to ``PineconeRetriever``.
        """
        logging.info("Initializing ChatbotUI...")
        self.trainer = ModelTrainer(model_name)
        self.retriever = PineconeRetriever(pinecone_api, pinecone_index, pinecone_namespace)
        self.use_rag = True  # RAG enabled by default; chatbot_response honors this flag
        logging.info("ChatbotUI initialized successfully with RAG enabled.")

    def chatbot_response(self, input_text, *, top_k=1, max_new_tokens=2000):
        """Generate a reply to ``input_text`` using retrieved context and the trained model.

        Args:
            input_text: The user's message.
            top_k: Number of documents requested from the retriever.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The generated reply only (prompt removed), whitespace-stripped.
        """
        context = ""
        if self.use_rag:
            # NOTE(review): assumes retrieve_context returns a str (or None) — confirm in backend.rag.
            retrieved_docs = self.retriever.retrieve_context(input_text, top_k=top_k)
            context = (retrieved_docs or "").strip()

        latest_prompt = (f"{self._PRESET}\n\n### Context: {context}\n\n"
                         f"### User: {input_text.strip()}\n\n### Response:")

        tokenizer = self.trainer.tokenizer
        model = self.trainer.model

        inputs = tokenizer(latest_prompt, return_tensors="pt")
        # Put the input tensors on the same device as the model (no-op on CPU),
        # otherwise generate() fails when the model lives on a GPU.
        device = getattr(model, "device", None)
        if device is not None:
            inputs = inputs.to(device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,  # larger values allow longer answers but slow generation
            do_sample=True,
            top_p=0.95,
            temperature=0.1,
            repetition_penalty=1.2,
        )

        generated = outputs[0]
        if not getattr(getattr(model, "config", None), "is_encoder_decoder", False):
            # Decoder-only models echo the prompt tokens before the continuation.
            # Drop them by token count: string-matching the decoded text against
            # the prompt is unreliable because decoding need not round-trip exactly.
            generated = generated[inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(generated, skip_special_tokens=True).strip()

    def clear_conversation(self):
        """Clear the conversation history and reset the input box.

        Returns:
            (chatbot history, textbox value, state), matching the outputs wired in ``launch``.
        """
        logging.info("Clearing conversation history.")
        return [], "", []

    def launch(self):
        """Build the Gradio Blocks layout, wire the buttons, and start the server."""
        logging.info("Launching chatbot UI...")
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.Markdown("""
            <h1 style='text-align: center; color: #2E3A59;'>HAP Chatbot</h1>
            <p style='text-align: center; color: #4A5568;'>An intelligent HR legal assistant powered by AI.</p>
            """)

            with gr.Row():
                chatbot = gr.Chatbot(label="Chat", height=400)

            with gr.Row():
                msg = gr.Textbox(label="Your Message", placeholder="Enter your message here...", lines=2)
                send_btn = gr.Button("Send", variant="primary")

            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")

            # Conversation history as a list of (user, assistant) pairs.
            state = gr.State([])

            def user_message(message, history):
                # Ignore blank submissions instead of running the model on an empty prompt.
                if not message or not message.strip():
                    return history, "", history
                response = self.chatbot_response(message)
                history = history + [(message, response)]
                # Return updated conversation history, clear the textbox, and update the state
                return history, "", history

            send_btn.click(user_message, inputs=[msg, state], outputs=[chatbot, msg, state])
            clear_btn.click(self.clear_conversation, inputs=[], outputs=[chatbot, msg, state], queue=False)

        demo.launch()

if __name__ == "__main__":
    # Fail fast with a clear message instead of handing None credentials to Pinecone.
    missing = [name for name in ("PINECONE_API_KEY", "PINECONE_INDEX") if not os.getenv(name)]
    if missing:
        raise SystemExit(f"Missing required environment variable(s): {', '.join(missing)}")

    chatbot = ChatbotUI(
        model_name="sainoforce/modelv5",
        pinecone_api=os.getenv("PINECONE_API_KEY"),
        pinecone_index=os.getenv("PINECONE_INDEX"),
        # NOTE(review): namespace is not validated (None if unset) — confirm PineconeRetriever accepts that.
        pinecone_namespace=os.getenv("PINECONE_NAMESPACE"),
    )
    chatbot.launch()