jiya2's picture
Upload app.py
0eb943c verified
import html
import os

import streamlit as st
import torch
from dotenv import load_dotenv
from llama_cpp import Llama
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
load_dotenv()

# Model identifiers come from the environment (load_dotenv above also reads a
# local .env file). All three are required for the app to start.
REPO_ID = os.getenv("REPO_ID")
INTENT_MODEL = os.getenv("INTENT_MODEL")
CHITCHAT_MODEL = os.getenv("CHITCHAT_MODEL")

# Collect every missing/empty variable so the error names what to fix,
# instead of raising and immediately catching a generic exception.
_missing_env_vars = [
    name
    for name, value in (
        ("REPO_ID", REPO_ID),
        ("INTENT_MODEL", INTENT_MODEL),
        ("CHITCHAT_MODEL", CHITCHAT_MODEL),
    )
    if not value
]
if _missing_env_vars:
    st.error(
        "Environment setup failed: missing required environment variable(s): "
        + ", ".join(_missing_env_vars)
    )
    # st.stop() is Streamlit's supported way to end this script run after
    # showing the error; nothing below executes.
    st.stop()
@st.cache_resource(show_spinner=False)
def _load_models(repo_id, intent_model_name, chitchat_model_name):
    """Load every model the app needs, once per server process.

    Streamlit re-executes this whole script on every user interaction, so
    without caching each chat message would reload all weights (including
    the F16 GGUF file). ``st.cache_resource`` keeps the returned objects
    alive and returns the same instances on later runs with the same args.
    ``show_spinner=False`` keeps the cache from drawing anything before
    ``st.set_page_config`` runs further down the script.

    Args:
        repo_id: Hugging Face repo holding ``unsloth.F16.gguf`` for the HR model.
        intent_model_name: Model id for the intent tokenizer/classifier.
        chitchat_model_name: Model id for the seq2seq chit-chat model.

    Returns:
        Tuple ``(llm_hr, intent_tokenizer, intent_model, tokenizer, model,
        pipe_chitchat)``.
    """
    # NOTE(review): cached objects are shared by every session of this server
    # process; confirm concurrent calls (especially into the llama_cpp model)
    # are acceptable, or serialize access.
    llm = Llama.from_pretrained(repo_id=repo_id, filename="unsloth.F16.gguf")
    intent_tok = AutoTokenizer.from_pretrained(intent_model_name)
    intent_clf = AutoModelForSequenceClassification.from_pretrained(intent_model_name)
    chat_tok = AutoTokenizer.from_pretrained(chitchat_model_name)
    chat_model = AutoModelForSeq2SeqLM.from_pretrained(chitchat_model_name)
    chat_pipe = pipeline("text2text-generation", model=chat_model, tokenizer=chat_tok)
    return llm, intent_tok, intent_clf, chat_tok, chat_model, chat_pipe


try:
    # Same module-level names as before, so the rest of the script is unchanged.
    (
        llm_hr,
        intent_tokenizer,
        intent_model,
        tokenizer,
        model,
        pipe_chitchat,
    ) = _load_models(REPO_ID, INTENT_MODEL, CHITCHAT_MODEL)
except Exception as e:
    # Failures are not cached, so the next rerun retries the load.
    st.error(f"Model loading failed: {e}")
    st.stop()
# HR intent names. The chat handler indexes this list with the integer returned
# by determine_intent (an argmax over the classifier's logits), so the order
# presumably has to match the intent model's label ids — TODO confirm.
intent_labels = [
    "Employee Benefits & Policies",
    "Employee Support & Self-Service",
    "Recruitment & Onboarding",
]

# Every sidebar choice: casual chit-chat first, then the HR intents.
# NOTE(review): the emoji prefix below looks mis-encoded (mojibake of 😂). It is
# kept byte-for-byte because other code compares against this exact string.
total_labels = ["πŸ˜‚ Let's Chat & Chill", *intent_labels]
# Browser-tab title and full-width layout.
st.set_page_config(
    page_title="AI HR Chatbot",
    layout="wide",
)

# Global stylesheet: chat-bubble classes (chat-box / user-msg / bot-msg) and
# the title style used by the header markup, injected as raw HTML.
st.markdown(
    """
<style>
body { background-color: #f6f9fc; }
.sidebar .sidebar-content { background-color: #ffffff; }
.chat-container { max-height: 500px; overflow-y: auto; padding: 1rem; display: flex; flex-direction: column; }
.chat-box { display: inline-block; border-radius: 10px; padding: 8px 12px; margin: 5px 0; max-width: 70%; word-wrap: break-word; }
.user-msg { background-color: #708090; text-align: right; align-self: flex-end; color: white; }
.bot-msg { background-color: #d3d3d3; text-align: left; align-self: flex-start; color: black; }
div.title { font-size: 2rem; font-weight: bold; margin: 20px 0; color: #0077cc; }
</style>
""",
    unsafe_allow_html=True,
)
# Per-session defaults, written only when the key is not present yet:
# the active intent starts at the first sidebar option, and the chat
# history starts empty.
st.session_state.setdefault("intent", total_labels[0])
st.session_state.setdefault("messages", [])
def select_intent(current_intent):
    """Draw the sidebar intent selector and return the chosen label.

    Args:
        current_intent: Label used to seed the selector the first time it is
            drawn in a session. A value not in ``total_labels`` falls back to
            ``total_labels[0]``. After that first draw the widget keeps its own
            value in ``st.session_state``, so later values are not re-applied.

    Returns:
        The selected entry of ``total_labels`` (undecorated).
    """
    # Label, options and key stay identical on every rerun. Streamlit treats an
    # unkeyed widget whose options or index change between runs as a brand-new
    # widget and resets it. Decorating the active option (e.g. a check-mark
    # prefix) or recomputing ``index`` from the previous choice therefore made
    # the marker lag one run behind and could drop the user's next click.
    widget_key = "intent_selector"
    if widget_key not in st.session_state:
        st.session_state[widget_key] = (
            current_intent if current_intent in total_labels else total_labels[0]
        )
    return st.sidebar.radio("💼 Select Your HR Intent", total_labels, key=widget_key)
def determine_intent(text):
    """Run the intent classifier on *text* and return the top class index.

    The index is the argmax over the classifier's logits. If tokenization or
    inference raises, the error is shown in the UI and ``0`` is returned.
    """
    try:
        encoded = intent_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        # Inference only; no autograd graph is needed.
        with torch.no_grad():
            scores = intent_model(**encoded).logits
        return scores.argmax(dim=1).item()
    except Exception as e:
        st.error(f"Intent determination failed: {e}")
        return 0
def render_chat():
    """Replay the stored conversation as styled chat bubbles.

    Reads ``st.session_state.messages``: dicts with ``role``, ``content`` and
    an optional ``avatar``. ``role == "user"`` gets the ``user-msg`` style and
    every other role gets ``bot-msg``. Nothing is drawn when there is no
    history yet.
    """
    if not st.session_state.messages:
        return

    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    for msg in st.session_state.messages:
        role_class = "user-msg" if msg["role"] == "user" else "bot-msg"
        # Content is user input or model output and goes into raw HTML, so
        # escape it to prevent markup injection. Newlines become <br> so the
        # bubble stays one HTML line; a blank line would otherwise end the
        # HTML block and spill the rest of the text outside the bubble.
        safe_content = html.escape(str(msg["content"])).replace("\n", "<br>")
        with st.chat_message(msg["role"], avatar=msg.get("avatar")):
            st.markdown(
                f'<div class="chat-box {role_class}">{safe_content}</div>',
                unsafe_allow_html=True,
            )
    st.markdown('</div>', unsafe_allow_html=True)
def generate_hr_response(user_prompt, intent, max_tokens=100):
    """Ask the HR model (``llm_hr``) to answer an employee query.

    Args:
        user_prompt: The employee's question, placed in the prompt's Input section.
        intent: Intent label, used to phrase the prompt's Instruction section.
        max_tokens: Upper bound on generated tokens (default 100, the previous
            hard-coded value).

    Returns:
        The reply text from the first completion choice. If generation raises,
        the error is shown in the UI and a fixed apology string is returned.
    """
    hr_prompt = """You are an HR Assistant at our company.
Your role is to assist employees by providing accurate and concise responses regarding company policies and HR-related questions.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
    instruction = f"Answer the HR-related query categorized as '{intent}'."
    # The Response slot is left empty for the model to fill. No EOS token is
    # appended: an end-of-sequence marker belongs only in training-time
    # formatting. The one previously used also came from the chit-chat model's
    # tokenizer, not the HR model's.
    formatted_prompt = hr_prompt.format(instruction, user_prompt, "")
    try:
        response = llm_hr.create_chat_completion(
            messages=[{"role": "user", "content": formatted_prompt}],
            max_tokens=max_tokens,
        )
        return response["choices"][0]["message"]["content"]
    except Exception as e:
        st.error(f"Failed to generate HR response: {e}")
        return "⚠️ Sorry, I couldn't process your HR request at the moment."
# Page heading and tagline (the "title" class is styled in the CSS block).
st.markdown(
    '<div class="title">AI HR Chatbot </div>',
    unsafe_allow_html=True,
)
st.markdown(
    "<h6 style='color:rgb(131, 123, 160);'>"
    "Your personal assistant for HR queries and support."
    "</h6>",
    unsafe_allow_html=True,
)

# Sync the active intent with the sidebar, then redraw the stored history.
st.session_state["intent"] = select_intent(st.session_state["intent"])
render_chat()
# Handle one submitted chat message per script run: echo it, route it to the
# chit-chat pipeline or the HR model, then show and store the reply.
try:
    if prompt := st.chat_input("Ask me anything related to HR or just chat casually..."):
        # The raw text is stored; escaping happens only when it is rendered.
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": "man.png"})
        with st.chat_message("user", avatar="man.png"):
            # Escape before embedding in raw HTML (prevents markup injection);
            # <br> keeps multi-line input inside a single bubble.
            safe_prompt = html.escape(prompt).replace("\n", "<br>")
            st.markdown(f'<div class="chat-box user-msg">{safe_prompt}</div>', unsafe_allow_html=True)

        # total_labels[0] is the chit-chat option. Comparing against it, rather
        # than repeating the label literal, keeps the two from drifting apart.
        if st.session_state.intent == total_labels[0]:
            with st.spinner("😄 Chitchatting..."):
                try:
                    result = pipe_chitchat(prompt)
                    reply = result[0]['generated_text'].strip()
                except Exception as e:
                    st.error(f"Chitchat generation failed: {e}")
                    reply = "Oops, I couldn't come up with a witty reply! 😅"
        else:
            detected_intent = determine_intent(prompt)
            if 0 <= detected_intent < len(intent_labels):
                predicted_label = intent_labels[detected_intent]
            else:
                # The classifier returned an index with no entry in
                # intent_labels; answer under the user's selected intent
                # instead of failing with IndexError.
                predicted_label = st.session_state.intent
            if predicted_label != st.session_state.intent:
                st.warning(f"🔍 Automatically switched to **{predicted_label}** based on your query.")
            with st.spinner("Generating response with specialized HR model..."):
                reply = generate_hr_response(prompt, predicted_label)

        st.session_state.messages.append({"role": "bot", "content": reply, "avatar": "robot.png"})
        with st.chat_message("bot", avatar="robot.png"):
            # Model output is untrusted too: escape it the same way.
            safe_reply = html.escape(str(reply)).replace("\n", "<br>")
            st.markdown(f'<div class="chat-box bot-msg">{safe_reply}</div>', unsafe_allow_html=True)
except Exception as e:
    st.error(f"Something went wrong while processing your input: {e}")