File size: 6,469 Bytes
0eb943c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import html
import os

import streamlit as st
import torch
from dotenv import load_dotenv
from llama_cpp import Llama
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

# Load variables from a .env file into the process environment.
load_dotenv()

try:
    # Model identifiers handed to the from_pretrained() loaders further down.
    REPO_ID, INTENT_MODEL, CHITCHAT_MODEL = (
        os.getenv(var_name)
        for var_name in ("REPO_ID", "INTENT_MODEL", "CHITCHAT_MODEL")
    )

    # An unset or empty value for any of the three is fatal.
    if not (REPO_ID and INTENT_MODEL and CHITCHAT_MODEL):
        raise EnvironmentError("One or more required environment variables are missing.")

except Exception as exc:
    # Show the problem in the UI, then abort the script run.
    st.error(f"Environment setup failed: {exc}")
    raise SystemExit(exc)

@st.cache_resource(show_spinner=False)
def _load_models(repo_id, intent_model_name, chitchat_model_name):
    """Load every model the app needs and return them as one tuple.

    Streamlit re-executes this script from top to bottom on each user
    interaction.  Wrapping the loading code in ``st.cache_resource`` keeps the
    loaded objects alive between reruns, so the weights are read once instead
    of on every message.  ``show_spinner=False`` keeps the cache from drawing
    a spinner before ``st.set_page_config`` is called later in the script.

    NOTE(review): cached resources are shared between sessions; confirm the
    models tolerate concurrent use if several users are expected.

    Args:
        repo_id: Repository holding the GGUF file for the HR model.
        intent_model_name: Identifier of the intent classification model.
        chitchat_model_name: Identifier of the seq2seq chit-chat model.

    Returns:
        ``(llm_hr, intent_tokenizer, intent_model, tokenizer, model,
        pipe_chitchat)`` in that order.
    """
    hr_llm = Llama.from_pretrained(
        repo_id=repo_id,
        filename="unsloth.F16.gguf",
    )
    intent_tok = AutoTokenizer.from_pretrained(intent_model_name)
    intent_clf = AutoModelForSequenceClassification.from_pretrained(intent_model_name)
    chat_tok = AutoTokenizer.from_pretrained(chitchat_model_name)
    chat_model = AutoModelForSeq2SeqLM.from_pretrained(chitchat_model_name)
    chat_pipe = pipeline('text2text-generation', model=chat_model, tokenizer=chat_tok)
    return hr_llm, intent_tok, intent_clf, chat_tok, chat_model, chat_pipe


try:
    (
        llm_hr,
        intent_tokenizer,
        intent_model,
        tokenizer,
        model,
        pipe_chitchat,
    ) = _load_models(REPO_ID, INTENT_MODEL, CHITCHAT_MODEL)

except Exception as e:
    # Report outside the cached function so the message is shown on this run.
    st.error(f"Model loading failed: {e}")
    raise SystemExit(e)


# HR categories; the chat handler indexes this list with the class id returned
# by determine_intent(), so the order matters.
intent_labels = [
    "Employee Benefits & Policies",
    "Employee Support & Self-Service",
    "Recruitment & Onboarding",
]

# Everything offered in the sidebar: the casual-chat mode first, then the HR
# categories.
# NOTE(review): the leading characters of the first label look like a
# mis-encoded emoji. The chat handler compares against the identical literal,
# so any repair has to change both places together.
total_labels = ["πŸ˜‚ Let's Chat & Chill", *intent_labels]


st.set_page_config(layout="wide", page_title="AI HR Chatbot")

# Custom CSS for the page background, the sidebar, the chat bubbles
# (.chat-box with .user-msg / .bot-msg) and the page title.
_PAGE_CSS = """

<style>

body { background-color: #f6f9fc; }

.sidebar .sidebar-content { background-color: #ffffff; }

.chat-container { max-height: 500px; overflow-y: auto; padding: 1rem; display: flex; flex-direction: column; }

.chat-box { display: inline-block; border-radius: 10px; padding: 8px 12px; margin: 5px 0; max-width: 70%; word-wrap: break-word; }

.user-msg { background-color: #708090; text-align: right; align-self: flex-end; color: white; }

.bot-msg { background-color: #d3d3d3; text-align: left; align-self: flex-start; color: black; }

div.title { font-size: 2rem; font-weight: bold; margin: 20px 0; color: #0077cc; }

</style>

"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

if "intent" not in st.session_state:
    st.session_state.intent = total_labels[0]
if "messages" not in st.session_state:
    st.session_state.messages = []


def select_intent(current_intent):
    """Render the sidebar intent picker and return the chosen intent label.

    The entry equal to ``current_intent`` is displayed with a check-mark
    prefix.  That prefix is removed again before the selection is returned.

    Args:
        current_intent: Label currently stored as the active intent.

    Returns:
        The label selected in the sidebar, without the check-mark prefix.
    """
    # The marker and the radio caption previously held mis-encoded emoji
    # bytes; they are restored to the intended characters here.
    marker = "✅ "
    display_labels = [
        f"{marker}{intent}" if intent == current_intent else intent
        for intent in total_labels
    ]

    # Pre-select the marked entry; fall back to the first option when the
    # current intent is not one of the known labels.
    marked_current = f"{marker}{current_intent}"
    default_index = (
        display_labels.index(marked_current)
        if marked_current in display_labels
        else 0
    )

    selected = st.sidebar.radio(
        "💼 Select Your HR Intent",
        display_labels,
        index=default_index,
    )

    # Remove only a leading marker; str.replace() would also strip the same
    # text from the middle of a label.
    if selected.startswith(marker):
        return selected[len(marker):]
    return selected


def determine_intent(text):
    """Classify ``text`` with the intent model and return the winning class id.

    Args:
        text: The user's message.

    Returns:
        Index of the highest logit produced by ``intent_model``.  If anything
        fails, the error is shown in the UI and ``0`` is returned instead.
    """
    try:
        encoded = intent_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
        )
        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            output = intent_model(**encoded)
        return output.logits.argmax(dim=1).item()
    except Exception as err:
        st.error(f"Intent determination failed: {err}")
        return 0


def render_chat():
    """Replay the conversation stored in ``st.session_state.messages``.

    Each message is drawn inside ``st.chat_message`` as a styled ``div``.
    Nothing is rendered when the history is empty.

    The message text comes from the user and from model output, and it is
    injected into raw HTML (``unsafe_allow_html=True``).  It is therefore
    HTML-escaped first so that markup in a message is shown as text instead
    of being interpreted by the browser.
    """
    if not st.session_state.messages:
        return

    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    for msg in st.session_state.messages:
        role_class = "user-msg" if msg["role"] == "user" else "bot-msg"
        # str() mirrors what the f-string did before for non-string content.
        safe_content = html.escape(str(msg["content"]))
        with st.chat_message(msg["role"], avatar=msg.get("avatar")):
            st.markdown(
                f'<div class="chat-box {role_class}">{safe_content}</div>',
                unsafe_allow_html=True,
            )
    st.markdown('</div>', unsafe_allow_html=True)


def generate_hr_response(user_prompt, intent):
    """Ask the HR model to answer ``user_prompt`` for the given intent.

    The query is placed in an instruction/input/response template and sent to
    ``llm_hr`` as a single user message, limited to 100 generated tokens.

    Args:
        user_prompt: The user's question.
        intent: HR category label inserted into the instruction line.

    Returns:
        The model's reply text, or a fixed apology string when prompt
        construction or generation fails (the error is also shown in the UI).
    """
    hr_prompt = """You are an HR Assistant at our company.

    Your role is to assist employees by providing accurate and concise responses regarding company policies and HR-related questions.



    ### Instruction:

    {}



    ### Input:

    {}



    ### Response:

    {}"""
    instruction = f"Answer the HR-related query categorized as '{intent}'."

    try:
        # `tokenizer` is the chit-chat model's tokenizer, and its eos_token
        # may be None. Concatenating None raised a TypeError outside the
        # try block, which skipped the fallback reply below.
        # NOTE(review): appending an EOS token to an inference prompt looks
        # like a leftover from training code -- confirm it is wanted.
        eos_token = tokenizer.eos_token or ""
        formatted_prompt = hr_prompt.format(instruction, user_prompt, "") + eos_token

        response = llm_hr.create_chat_completion(
            messages=[{"role": "user", "content": formatted_prompt}],
            max_tokens=100,
        )
        return response["choices"][0]["message"]["content"]
    except Exception as e:
        st.error(f"Failed to generate HR response: {e}")
        return "⚠️ Sorry, I couldn't process your HR request at the moment."


# Page heading and tagline, written as raw HTML fragments.
_HEADER_FRAGMENTS = (
    '<div class="title">AI HR Chatbot </div>',
    "<h6 style='color:rgb(131, 123, 160);'>Your personal assistant for HR queries and support.</h6>",
)
for _fragment in _HEADER_FRAGMENTS:
    st.markdown(_fragment, unsafe_allow_html=True)

# Draw the sidebar picker and store its selection as the active intent,
# then replay the conversation so far.
_chosen_intent = select_intent(st.session_state.intent)
st.session_state.intent = _chosen_intent
render_chat()


# Handle one new chat message: record and show it, produce a reply with either
# the chit-chat pipeline or the HR model, then record and show the reply.
try:
    if prompt := st.chat_input("Ask me anything related to HR or just chat casually..."):
        # History keeps the raw text; escaping happens only when rendering.
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": "man.png"})
        with st.chat_message("user", avatar="man.png"):
            # The prompt is untrusted input going into raw HTML, so escape it.
            st.markdown(
                f'<div class="chat-box user-msg">{html.escape(prompt)}</div>',
                unsafe_allow_html=True,
            )

        # total_labels[0] is the casual-chat entry; comparing against it
        # avoids keeping a second copy of that label literal in sync.
        if st.session_state.intent == total_labels[0]:
            with st.spinner("😄 Chitchatting..."):
                try:
                    result = pipe_chitchat(prompt)
                    reply = result[0]['generated_text'].strip()
                except Exception as e:
                    st.error(f"Chitchat generation failed: {e}")
                    reply = "Oops, I couldn't come up with a witty reply! 😅"
        else:
            # The classifier, not the sidebar choice, decides which HR
            # category the question is answered under.
            detected_intent = determine_intent(prompt)
            predicted_label = intent_labels[detected_intent]

            if predicted_label != st.session_state.intent:
                # NOTE(review): the message says "switched", but the stored
                # intent is not updated here -- confirm that is intended.
                st.warning(f"🔍 Automatically switched to **{predicted_label}** based on your query.")

            with st.spinner("Generating response with specialized HR model..."):
                reply = generate_hr_response(prompt, predicted_label)

        st.session_state.messages.append({"role": "bot", "content": reply, "avatar": "robot.png"})
        with st.chat_message("bot", avatar="robot.png"):
            # Model output is escaped as well before going into raw HTML.
            st.markdown(
                f'<div class="chat-box bot-msg">{html.escape(str(reply))}</div>',
                unsafe_allow_html=True,
            )

except Exception as e:
    st.error(f"Something went wrong while processing your input: {e}")