SF001-123456 commited on
Commit
db7016c
·
verified ·
1 Parent(s): 10f44c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -56
app.py CHANGED
@@ -1,59 +1,90 @@
1
- def generate_response(prompt):
2
- preset = ("You are a HR legal assistant. You do not respond as 'User' or pretend to be 'User'. "
3
- "You only respond once as 'Assistant'. Avoid Yes or No.")
4
- # Try to get additional context from query_pinecone if defined
5
- try:
6
- context = query_pinecone(prompt).strip()
7
- except NameError:
8
- context = ""
9
-
10
- latest_prompt = f"{preset}\n\n### Context: {context}\n\n### User: {prompt}\n\n### Response:"
11
-
12
- inputs = tokenizer(latest_prompt, return_tensors="pt").to(model.device)
13
- outputs = model.generate(
14
- **inputs,
15
- max_new_tokens=2000,
16
- do_sample=True,
17
- top_p=0.95,
18
- temperature=0.1,
19
- repetition_penalty=1.2,
20
- )
21
- response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
22
- if response.startswith(latest_prompt):
23
- response = response[len(latest_prompt):].strip()
24
-
25
- return response
26
-
27
- def user_message(message, history):
28
- response = generate_response(message)
29
- # Append the new conversation to history as a tuple: (user_message, assistant_response)
30
- history = history + [(message, response)]
31
- # Return the updated history and clear the textbox (by returning an empty string)
32
- return history, ""
33
-
34
- def clear_conversation():
35
- # Clear the chat history and message textbox
36
- return [], "", []
37
-
38
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
39
- gr.Markdown("""
40
- <h1 style='text-align: center; color: #2E3A59;'>HAP Chatbot</h1>
41
- <p style='text-align: center; color: #4A5568;'>An intelligent HR legal assistant powered by AI.</p>
42
- """)
43
-
44
- with gr.Row():
45
- chatbot = gr.Chatbot(label="Chat", height=400)
46
-
47
- with gr.Row():
48
- msg = gr.Textbox(label="Your Message", placeholder="Enter your message here...", lines=2)
49
- send_btn = gr.Button("Send", variant="primary")
50
-
51
- with gr.Row():
52
- clear_btn = gr.Button("Clear Chat", variant="secondary")
53
 
54
- state = gr.State([]) # Initialize conversation history as an empty list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- send_btn.click(user_message, inputs=[msg, state], outputs=[chatbot, msg])
57
- clear_btn.click(clear_conversation, outputs=[chatbot, msg, state], queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- demo.launch()
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import logging
3
+ from backend.train import ModelTrainer
4
+ from backend.rag import PineconeRetriever
5
+ from dotenv import load_dotenv
6
+ import os
7
+
8
+ load_dotenv()
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
+
13
+ class ChatbotUI:
14
+ """Class for the Gradio-based chatbot UI with RAG enabled by default."""
15
+ def __init__(self, model_name, pinecone_api, pinecone_index, pinecone_namespace):
16
+ logging.info("Initializing ChatbotUI...")
17
+ self.trainer = ModelTrainer(model_name)
18
+ self.retriever = PineconeRetriever(pinecone_api, pinecone_index, pinecone_namespace)
19
+ self.use_rag = True # RAG enabled by default
20
+ logging.info("ChatbotUI initialized successfully with RAG enabled.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def chatbot_response(self, input_text):
23
+ """Generate response using retrieved context and the trained model."""
24
+ # Retrieve relevant context from Pinecone
25
+ retrieved_docs = self.retriever.retrieve_context(input_text, top_k=1)
26
+
27
+ preset = ("You are a HR legal assistant. You do not respond as 'User' or pretend to be 'User'. "
28
+ "You only respond once as 'Assistant'. Avoid Yes or No.")
29
+ latest_prompt = (f"{preset}\n\n### Context: {retrieved_docs.strip()}\n\n"
30
+ f"### User: {input_text.strip()}\n\n### Response:")
31
+
32
+ inputs = self.trainer.tokenizer(latest_prompt, return_tensors="pt")
33
+ outputs = self.trainer.model.generate(
34
+ **inputs,
35
+ max_new_tokens=2000, # Reduce token size to optimize speed
36
+ do_sample=True,
37
+ top_p=0.95,
38
+ temperature=0.1,
39
+ repetition_penalty=1.2,
40
+ )
41
+
42
+ response = self.trainer.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
43
+ if response.startswith(latest_prompt):
44
+ response = response[len(latest_prompt):].strip()
45
+ return response
46
+
47
+ def clear_conversation(self):
48
+ """Clears the entire conversation history and resets the input box."""
49
+ logging.info("Clearing conversation history.")
50
+ return [], "", []
51
 
52
+ def launch(self):
53
+ logging.info("Launching chatbot UI...")
54
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
55
+ gr.Markdown("""
56
+ <h1 style='text-align: center; color: #2E3A59;'>HAP Chatbot</h1>
57
+ <p style='text-align: center; color: #4A5568;'>An intelligent HR legal assistant powered by AI.</p>
58
+ """)
59
+
60
+ with gr.Row():
61
+ chatbot = gr.Chatbot(label="Chat", height=400)
62
+
63
+ with gr.Row():
64
+ msg = gr.Textbox(label="Your Message", placeholder="Enter your message here...", lines=2)
65
+ send_btn = gr.Button("Send", variant="primary")
66
+
67
+ with gr.Row():
68
+ clear_btn = gr.Button("Clear Chat", variant="secondary")
69
+
70
+ state = gr.State([])
71
+
72
+ def user_message(message, history):
73
+ response = self.chatbot_response(message)
74
+ history = history + [(message, response)]
75
+ # Return updated conversation history, clear the textbox, and update the state
76
+ return history, "", history
77
+
78
+ send_btn.click(user_message, inputs=[msg, state], outputs=[chatbot, msg, state])
79
+ clear_btn.click(self.clear_conversation, inputs=[], outputs=[chatbot, msg, state], queue=False)
80
+
81
+ demo.launch()
82
 
83
+ if __name__ == "__main__":
84
+ chatbot = ChatbotUI(
85
+ model_name="sainoforce/modelv5",
86
+ pinecone_api=os.getenv("PINECONE_API_KEY"),
87
+ pinecone_index=os.getenv("PINECONE_INDEX"),
88
+ pinecone_namespace=os.getenv("PINECONE_NAMESPACE")
89
+ )
90
+ chatbot.launch()