from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
import streamlit as st
from llama_cpp import Llama
from dotenv import load_dotenv
import os
import torch

# Load variables from a local .env file into the process environment.
load_dotenv()

try:
    # Required model identifiers; all must be present in the environment.
    REPO_ID = os.getenv('REPO_ID')
    INTENT_MODEL = os.getenv('INTENT_MODEL')
    CHITCHAT_MODEL = os.getenv('CHITCHAT_MODEL')
    if not (REPO_ID and INTENT_MODEL and CHITCHAT_MODEL):
        raise EnvironmentError("One or more required environment variables are missing.")
except Exception as e:
    # Surface the problem in the UI, then abort the script run.
    st.error(f"Environment setup failed: {e}")
    raise SystemExit(e)
try:
    # HR-specialised llama.cpp model, pulled as a GGUF checkpoint from the hub.
    llm_hr = Llama.from_pretrained(
        repo_id=REPO_ID,
        filename="unsloth.F16.gguf"
    )
    # Intent classifier: tokenizer + sequence-classification head.
    intent_tokenizer = AutoTokenizer.from_pretrained(INTENT_MODEL)
    intent_model = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL)
    # Seq2seq chit-chat model wrapped in a text2text-generation pipeline.
    tokenizer = AutoTokenizer.from_pretrained(CHITCHAT_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(CHITCHAT_MODEL)
    pipe_chitchat = pipeline('text2text-generation', model=model, tokenizer=tokenizer)
except Exception as e:
    # Any download/load failure is fatal: report it and stop the app.
    st.error(f"Model loading failed: {e}")
    raise SystemExit(e)
# Intent categories handled by the specialised HR model; index order must
# match the intent classifier's output labels.
intent_labels = ["Employee Benefits & Policies", "Employee Support & Self-Service", "Recruitment & Onboarding"]
# Sidebar options: casual chit-chat first, then the HR intents.
total_labels = ["😂 Let's Chat & Chill", *intent_labels]

st.set_page_config(page_title="AI HR Chatbot", layout="wide")
# NOTE(review): this renders an empty string — the original inline CSS/HTML
# appears to have been stripped from this file; confirm against the repo.
st.markdown("""
""", unsafe_allow_html=True)

# One-time session-state initialisation (Streamlit reruns the script on
# every interaction, so guard against re-initialising).
st.session_state.setdefault("intent", total_labels[0])
st.session_state.setdefault("messages", [])
def select_intent(current_intent):
    """Render the sidebar intent picker and return the chosen label.

    The currently active intent is prefixed with a check mark in the
    option list; the mark is stripped from the returned value so the
    caller always receives a plain label.
    """
    marked = f"✅ {current_intent}"
    options = []
    for label in total_labels:
        options.append(marked if label == current_intent else label)
    # Preselect the active intent; fall back to the first option if the
    # current intent is not among the known labels.
    try:
        default_index = options.index(marked)
    except ValueError:
        default_index = 0
    choice = st.sidebar.radio(
        "💼 Select Your HR Intent",
        options,
        index=default_index,
    )
    return choice.replace("✅ ", "")
def determine_intent(text):
    """Classify *text* into one of the HR intent categories.

    Returns the argmax class index from the intent classifier, or 0
    (the first label) if anything goes wrong during inference.
    """
    try:
        encoded = intent_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        # Inference only — no gradients needed.
        with torch.no_grad():
            scores = intent_model(**encoded).logits
        return scores.argmax(dim=1).item()
    except Exception as e:
        st.error(f"Intent determination failed: {e}")
        return 0
def render_chat():
    """Replay the stored conversation history into the chat area.

    Each message is wrapped in a styled <div> so page CSS can distinguish
    user and bot bubbles; the whole history sits inside a container div.
    """
    # Guard clause: nothing to render on a fresh session.
    if not st.session_state.messages:
        return
    # NOTE(review): the original HTML markup was lost from this file
    # (the string literals were truncated mid-line). The wrapper and
    # bubble divs below are a reconstruction — confirm the class names
    # ("chat-container", "user-msg", "bot-msg") against the project CSS.
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    for msg in st.session_state.messages:
        role_class = "user-msg" if msg["role"] == "user" else "bot-msg"
        with st.chat_message(msg["role"], avatar=msg.get("avatar")):
            st.markdown(f'<div class="{role_class}">{msg["content"]}</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
def generate_hr_response(user_prompt, intent):
    """Answer an HR query with the fine-tuned llama model.

    Builds an Alpaca-style instruction prompt tagged with the detected
    *intent*, sends it through the chat-completion API, and returns the
    model's reply text. On any failure the error is shown in the UI and
    a fallback apology string is returned instead.
    """
    hr_prompt = """You are an HR Assistant at our company.
Your role is to assist employees by providing accurate and concise responses regarding company policies and HR-related questions.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
    instruction = f"Answer the HR-related query categorized as '{intent}'."
    # NOTE(review): eos_token comes from the chit-chat tokenizer, not the
    # llama model's own tokenizer — kept as-is to preserve behaviour, but
    # worth verifying this is intentional.
    prompt_text = hr_prompt.format(instruction, user_prompt, "") + tokenizer.eos_token
    try:
        completion = llm_hr.create_chat_completion(
            messages=[{"role": "user", "content": prompt_text}],
            max_tokens=100,
        )
        return completion["choices"][0]["message"]["content"]
    except Exception as e:
        st.error(f"Failed to generate HR response: {e}")
        return "⚠️ Sorry, I couldn't process your HR request at the moment."
# Page header. NOTE(review): the original HTML wrappers were stripped from
# this file (truncated string literals); the tags below are a minimal
# reconstruction that preserves the visible text — confirm against the repo.
st.markdown('<h1>AI HR Chatbot</h1>', unsafe_allow_html=True)
st.markdown("<p>Your personal assistant for HR queries and support.</p>", unsafe_allow_html=True)

# Sidebar intent selection persists across Streamlit reruns via session state.
st.session_state.intent = select_intent(st.session_state.intent)
render_chat()
try:
    if prompt := st.chat_input("Ask me anything related to HR or just chat casually..."):
        # Record and immediately echo the user's message.
        # NOTE(review): the bubble <div> markup below is a reconstruction —
        # the original HTML was stripped from this file (truncated literals).
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": "man.png"})
        with st.chat_message("user", avatar="man.png"):
            st.markdown(f'<div class="user-msg">{prompt}</div>', unsafe_allow_html=True)
        if st.session_state.intent == "😂 Let's Chat & Chill":
            # Casual mode: answer with the lightweight chit-chat pipeline.
            with st.spinner("😄 Chitchatting..."):
                try:
                    result = pipe_chitchat(prompt)
                    reply = result[0]['generated_text'].strip()
                except Exception as e:
                    st.error(f"Chitchat generation failed: {e}")
                    reply = "Oops, I couldn't come up with a witty reply! 😅"
        else:
            # HR mode: classify the query first, then route it to the
            # specialised model under the detected intent.
            detected_intent = determine_intent(prompt)
            predicted_label = intent_labels[detected_intent]
            if predicted_label != st.session_state.intent:
                # The warning announces a switch, but the sidebar state is
                # not updated here — preserved as-is from the original.
                st.warning(f"🔁 Automatically switched to **{predicted_label}** based on your query.")
            with st.spinner("Generating response with specialized HR model..."):
                reply = generate_hr_response(prompt, predicted_label)
        # Record and render the bot reply.
        st.session_state.messages.append({"role": "bot", "content": reply, "avatar": "robot.png"})
        with st.chat_message("bot", avatar="robot.png"):
            st.markdown(f'<div class="bot-msg">{reply}</div>', unsafe_allow_html=True)
except Exception as e:
    st.error(f"Something went wrong while processing your input: {e}")