# NOTE(review): scrape artifact removed here — the original capture included the
# Hugging Face "Spaces: Sleeping" status banner, which is not part of the program.
import os

import streamlit as st
import torch
from dotenv import load_dotenv
from llama_cpp import Llama
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

# Pull configuration from a local .env file into the process environment.
load_dotenv()
# Read required configuration from the environment; the app cannot run without
# all three values, so any missing one is surfaced in the UI and aborts startup.
try:
    REPO_ID = os.getenv('REPO_ID')
    INTENT_MODEL = os.getenv('INTENT_MODEL')
    CHITCHAT_MODEL = os.getenv('CHITCHAT_MODEL')
    # os.getenv returns None for unset names; empty strings are rejected too.
    if any(not value for value in (REPO_ID, INTENT_MODEL, CHITCHAT_MODEL)):
        raise EnvironmentError("One or more required environment variables are missing.")
except Exception as e:
    st.error(f"Environment setup failed: {e}")
    raise SystemExit(e)
# Load the three models the app depends on:
#   * llm_hr        — fine-tuned HR LLM served via llama.cpp (GGUF file),
#   * intent_model  — sequence classifier that routes HR queries,
#   * pipe_chitchat — seq2seq pipeline for casual conversation.
# Any load failure is fatal: report it in the UI and exit.
try:
    llm_hr = Llama.from_pretrained(repo_id=REPO_ID, filename="unsloth.F16.gguf")
    intent_tokenizer = AutoTokenizer.from_pretrained(INTENT_MODEL)
    intent_model = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(CHITCHAT_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(CHITCHAT_MODEL)
    pipe_chitchat = pipeline('text2text-generation', model=model, tokenizer=tokenizer)
except Exception as e:
    st.error(f"Model loading failed: {e}")
    raise SystemExit(e)
# Intent categories the classifier can route to. NOTE(review): "π" is mojibake
# of an emoji in the scraped source; it is kept byte-for-byte because the main
# handler compares st.session_state.intent against this exact string.
intent_labels = [
    "Employee Benefits & Policies",
    "Employee Support & Self-Service",
    "Recruitment & Onboarding",
]
total_labels = ["π Let's Chat & Chill"] + intent_labels
st.set_page_config(page_title="AI HR Chatbot", layout="wide")

# Page-wide CSS: chat bubbles (user right-aligned grey-blue, bot left-aligned
# light grey), a scrollable history container, and the title style.
st.markdown("""
<style>
body { background-color: #f6f9fc; }
.sidebar .sidebar-content { background-color: #ffffff; }
.chat-container { max-height: 500px; overflow-y: auto; padding: 1rem; display: flex; flex-direction: column; }
.chat-box { display: inline-block; border-radius: 10px; padding: 8px 12px; margin: 5px 0; max-width: 70%; word-wrap: break-word; }
.user-msg { background-color: #708090; text-align: right; align-self: flex-end; color: white; }
.bot-msg { background-color: #d3d3d3; text-align: left; align-self: flex-start; color: black; }
div.title { font-size: 2rem; font-weight: bold; margin: 20px 0; color: #0077cc; }
</style>
""", unsafe_allow_html=True)
| if "intent" not in st.session_state: | |
| st.session_state.intent = total_labels[0] | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
def select_intent(current_intent):
    """Render the sidebar intent picker and return the chosen intent label.

    The currently active intent is shown with a "β " marker prefix (mojibake of
    a check-mark emoji in the scraped source, preserved verbatim); the marker
    is stripped before returning so the caller always gets a clean label.
    """
    marked = f"β {current_intent}"
    options = []
    for label in total_labels:
        options.append(marked if label == current_intent else label)
    default_index = options.index(marked) if marked in options else 0
    choice = st.sidebar.radio("πΌ Select Your HR Intent", options, index=default_index)
    return choice.replace("β ", "")
def determine_intent(text):
    """Classify *text* with the fine-tuned intent model and return the label index.

    Falls back to index 0 (the first intent) on any error, after surfacing the
    failure in the UI — best-effort routing rather than a crash.
    """
    try:
        encoded = intent_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        # Inference only — no gradients needed.
        with torch.no_grad():
            scores = intent_model(**encoded).logits
        return int(torch.argmax(scores, dim=1).item())
    except Exception as e:
        st.error(f"Intent determination failed: {e}")
        return 0
def render_chat():
    """Replay the stored conversation from session state as styled chat bubbles."""
    if not st.session_state.messages:
        return
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    for entry in st.session_state.messages:
        # Pick the bubble style from the speaker role.
        css_class = "user-msg" if entry["role"] == "user" else "bot-msg"
        with st.chat_message(entry["role"], avatar=entry.get("avatar")):
            st.markdown(f'<div class="chat-box {css_class}">{entry["content"]}</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
def generate_hr_response(user_prompt, intent):
    """Ask the fine-tuned HR LLM to answer *user_prompt* under the given *intent*.

    Builds an Alpaca-style instruction/input/response prompt (the empty third
    slot is where the model continues) and returns the model's reply text, or a
    canned apology string if generation fails.
    """
    hr_prompt = """You are an HR Assistant at our company.
Your role is to assist employees by providing accurate and concise responses regarding company policies and HR-related questions.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
    instruction = f"Answer the HR-related query categorized as '{intent}'."
    # NOTE(review): eos_token here comes from the chit-chat tokenizer, not the
    # llama.cpp model's tokenizer — confirm this is intentional.
    formatted_prompt = hr_prompt.format(instruction, user_prompt, "") + tokenizer.eos_token
    try:
        completion = llm_hr.create_chat_completion(
            messages=[{"role": "user", "content": formatted_prompt}],
            max_tokens=100,
        )
        return completion["choices"][0]["message"]["content"]
    except Exception as e:
        st.error(f"Failed to generate HR response: {e}")
        return "β οΈ Sorry, I couldn't process your HR request at the moment."
| st.markdown('<div class="title">AI HR Chatbot </div>', unsafe_allow_html=True) | |
| st.markdown("<h6 style='color:rgb(131, 123, 160);'>Your personal assistant for HR queries and support.</h6>", unsafe_allow_html=True) | |
| st.session_state.intent = select_intent(st.session_state.intent) | |
| render_chat() | |
| try: | |
| if prompt := st.chat_input("Ask me anything related to HR or just chat casually..."): | |
| st.session_state.messages.append({"role": "user", "content": prompt, "avatar": "man.png"}) | |
| with st.chat_message("user", avatar="man.png"): | |
| st.markdown(f'<div class="chat-box user-msg">{prompt}</div>', unsafe_allow_html=True) | |
| if st.session_state.intent == "π Let's Chat & Chill": | |
| with st.spinner("π Chitchatting..."): | |
| try: | |
| result = pipe_chitchat(prompt) | |
| reply = result[0]['generated_text'].strip() | |
| except Exception as e: | |
| st.error(f"Chitchat generation failed: {e}") | |
| reply = "Oops, I couldn't come up with a witty reply! π " | |
| else: | |
| detected_intent = determine_intent(prompt) | |
| predicted_label = intent_labels[detected_intent] | |
| if predicted_label != st.session_state.intent: | |
| st.warning(f"π Automatically switched to **{predicted_label}** based on your query.") | |
| with st.spinner("Generating response with specialized HR model..."): | |
| reply = generate_hr_response(prompt, predicted_label) | |
| st.session_state.messages.append({"role": "bot", "content": reply, "avatar": "robot.png"}) | |
| with st.chat_message("bot", avatar="robot.png"): | |
| st.markdown(f'<div class="chat-box bot-msg">{reply}</div>', unsafe_allow_html=True) | |
| except Exception as e: | |
| st.error(f"Something went wrong while processing your input: {e}") | |