# Source: Hugging Face Spaces app file (commit 0eb943c, 6,469 bytes).
# NOTE(review): the original page header and line-number gutter were removed;
# they were web-extraction artifacts, not part of the program.
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
import streamlit as st
from llama_cpp import Llama
from dotenv import load_dotenv
import os
import torch
# Pull REPO_ID / INTENT_MODEL / CHITCHAT_MODEL from a local .env file into the
# process environment before reading them below.
load_dotenv()
try:
    # Hugging Face identifiers for the three models used by the app.
    REPO_ID = os.getenv('REPO_ID')
    INTENT_MODEL = os.getenv('INTENT_MODEL')
    CHITCHAT_MODEL = os.getenv('CHITCHAT_MODEL')
    # os.getenv returns None for missing keys — fail fast if any are absent.
    if not all([REPO_ID, INTENT_MODEL, CHITCHAT_MODEL]):
        raise EnvironmentError("One or more required environment variables are missing.")
except Exception as e:
    # Surface the problem in the UI, then abort: the app cannot run without its models.
    st.error(f"Environment setup failed: {e}")
    raise SystemExit(e)
try:
    # GGUF HR model served locally through llama.cpp.
    llm_hr = Llama.from_pretrained(
        repo_id=REPO_ID,
        filename="unsloth.F16.gguf"
    )
    # Sequence-classification model that routes a query to one of the HR intents.
    intent_tokenizer = AutoTokenizer.from_pretrained(INTENT_MODEL)
    intent_model = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL)
    # Seq2seq model used for casual (non-HR) conversation.
    tokenizer = AutoTokenizer.from_pretrained(CHITCHAT_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(CHITCHAT_MODEL)
    pipe_chitchat = pipeline('text2text-generation', model=model, tokenizer=tokenizer)
except Exception as e:
    # Loading is mandatory: report in the UI and abort on any failure.
    st.error(f"Model loading failed: {e}")
    raise SystemExit(e)

# Classifier labels, in the index order the intent model emits.
# NOTE(review): the "π " prefix below looks like a mangled emoji from the
# original source encoding — confirm the intended glyph. The literal must stay
# byte-identical to the comparison in the main input handler.
intent_labels = ["Employee Benefits & Policies", "Employee Support & Self-Service", "Recruitment & Onboarding"]
total_labels = ["π Let's Chat & Chill"] + intent_labels
st.set_page_config(page_title="AI HR Chatbot", layout="wide")
# Inline CSS: chat bubble styling (user right-aligned grey-blue, bot left-aligned grey),
# a scrollable chat container, and the page title style.
st.markdown("""
<style>
body { background-color: #f6f9fc; }
.sidebar .sidebar-content { background-color: #ffffff; }
.chat-container { max-height: 500px; overflow-y: auto; padding: 1rem; display: flex; flex-direction: column; }
.chat-box { display: inline-block; border-radius: 10px; padding: 8px 12px; margin: 5px 0; max-width: 70%; word-wrap: break-word; }
.user-msg { background-color: #708090; text-align: right; align-self: flex-end; color: white; }
.bot-msg { background-color: #d3d3d3; text-align: left; align-self: flex-start; color: black; }
div.title { font-size: 2rem; font-weight: bold; margin: 20px 0; color: #0077cc; }
</style>
""", unsafe_allow_html=True)
# Session defaults: start in casual-chat mode with an empty history.
if "intent" not in st.session_state:
    st.session_state.intent = total_labels[0]
if "messages" not in st.session_state:
    st.session_state.messages = []
def select_intent(current_intent):
    """Render the sidebar intent picker and return the chosen label.

    The currently active intent is displayed with a "β " marker prefix; the
    prefix is stripped before returning so callers always receive a bare
    label from `total_labels`.
    """
    # NOTE(review): "β" (and "πΌ" below) appear to be mangled emoji from the
    # original source encoding — the literals were broken across lines and have
    # been rejoined here; confirm the intended glyphs.
    display_labels = [f"β {intent}" if intent == current_intent else intent for intent in total_labels]
    selected = st.sidebar.radio(
        "πΌ Select Your HR Intent",
        display_labels,
        # Pre-select the active intent if present, otherwise default to the first entry.
        index=display_labels.index(f"β {current_intent}") if f"β {current_intent}" in display_labels else 0
    )
    return selected.replace("β ", "")
def determine_intent(text):
    """Classify `text` into one of `intent_labels` and return the label index.

    Falls back to index 0 (and shows an error in the UI) if tokenization or
    inference fails, so the caller always gets a valid index.
    """
    try:
        inputs = intent_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        # Inference only — skip gradient bookkeeping.
        with torch.no_grad():
            logits = intent_model(**inputs).logits
        predicted_label = torch.argmax(logits, dim=1).item()
        return predicted_label
    except Exception as e:
        st.error(f"Intent determination failed: {e}")
        return 0
def render_chat():
    """Re-render the full chat history stored in `st.session_state.messages`."""
    if st.session_state.messages:
        st.markdown('<div class="chat-container">', unsafe_allow_html=True)
        for msg in st.session_state.messages:
            # Pick the bubble CSS class based on who sent the message.
            role_class = "user-msg" if msg["role"] == "user" else "bot-msg"
            with st.chat_message(msg["role"], avatar=msg.get("avatar")):
                st.markdown(f'<div class="chat-box {role_class}">{msg["content"]}</div>', unsafe_allow_html=True)
        st.markdown('</div>', unsafe_allow_html=True)
def generate_hr_response(user_prompt, intent, max_tokens=100):
    """Answer an HR query with the fine-tuned llama.cpp model.

    Args:
        user_prompt: The employee's raw question.
        intent: A label from `intent_labels`, used to steer the instruction.
        max_tokens: Generation budget for the completion (default 100,
            matching the previous hard-coded limit).

    Returns:
        The model's reply text, or a fallback apology string on failure.
    """
    # Alpaca-style prompt template; continuation lines stay at column 0 so the
    # string content carries no leading indentation.
    hr_prompt = """You are an HR Assistant at our company.
Your role is to assist employees by providing accurate and concise responses regarding company policies and HR-related questions.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
    instruction = f"Answer the HR-related query categorized as '{intent}'."
    # NOTE(review): `tokenizer` is the chitchat model's tokenizer; its EOS token
    # may not match the llama model's vocabulary — confirm this is intentional.
    formatted_prompt = hr_prompt.format(instruction, user_prompt, "") + tokenizer.eos_token
    try:
        response = llm_hr.create_chat_completion(
            messages=[{"role": "user", "content": formatted_prompt}],
            max_tokens=max_tokens,
        )
        return response["choices"][0]["message"]["content"]
    except Exception as e:
        st.error(f"Failed to generate HR response: {e}")
        return "β οΈ Sorry, I couldn't process your HR request at the moment."
st.markdown('<div class="title">AI HR Chatbot </div>', unsafe_allow_html=True)
st.markdown("<h6 style='color:rgb(131, 123, 160);'>Your personal assistant for HR queries and support.</h6>", unsafe_allow_html=True)

# Sidebar intent selection, then replay the stored conversation.
st.session_state.intent = select_intent(st.session_state.intent)
render_chat()

try:
    if prompt := st.chat_input("Ask me anything related to HR or just chat casually..."):
        # Persist and echo the user's message.
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": "man.png"})
        with st.chat_message("user", avatar="man.png"):
            st.markdown(f'<div class="chat-box user-msg">{prompt}</div>', unsafe_allow_html=True)
        if st.session_state.intent == "π Let's Chat & Chill":
            # Casual mode: answer with the lightweight seq2seq pipeline.
            with st.spinner("π Chitchatting..."):
                try:
                    result = pipe_chitchat(prompt)
                    reply = result[0]['generated_text'].strip()
                except Exception as e:
                    st.error(f"Chitchat generation failed: {e}")
                    # NOTE(review): trailing "π" is a mangled emoji from the
                    # original source; the literal was broken across lines and
                    # has been rejoined here.
                    reply = "Oops, I couldn't come up with a witty reply! π"
        else:
            # HR mode: classify the query, warn if classification re-routes the
            # user away from their selected intent, then generate the answer.
            detected_intent = determine_intent(prompt)
            predicted_label = intent_labels[detected_intent]
            if predicted_label != st.session_state.intent:
                st.warning(f"π Automatically switched to **{predicted_label}** based on your query.")
            with st.spinner("Generating response with specialized HR model..."):
                reply = generate_hr_response(prompt, predicted_label)
        # Persist and render the bot's reply.
        st.session_state.messages.append({"role": "bot", "content": reply, "avatar": "robot.png"})
        with st.chat_message("bot", avatar="robot.png"):
            st.markdown(f'<div class="chat-box bot-msg">{reply}</div>', unsafe_allow_html=True)
except Exception as e:
    st.error(f"Something went wrong while processing your input: {e}")
|