"""Streamlit HR chatbot.

Casual messages go to a seq2seq chit-chat pipeline; HR questions are
classified by an intent model and answered by a GGUF model loaded via
llama-cpp.
"""

import html
import os
from typing import Optional

import streamlit as st
import torch
from dotenv import load_dotenv
from llama_cpp import Llama
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

# set_page_config must be the first Streamlit command executed by the script;
# previously st.error() could run before it on the failure paths.
st.set_page_config(page_title="AI HR Chatbot", layout="wide")

load_dotenv()

REQUIRED_ENV_VARS = ("REPO_ID", "INTENT_MODEL", "CHITCHAT_MODEL")
HR_MODEL_FILENAME = "unsloth.F16.gguf"

_missing_env = [name for name in REQUIRED_ENV_VARS if not os.getenv(name)]
if _missing_env:
    st.error(
        "Environment setup failed: missing required environment variable(s): "
        + ", ".join(_missing_env)
    )
    st.stop()

REPO_ID = os.getenv("REPO_ID")
INTENT_MODEL = os.getenv("INTENT_MODEL")
CHITCHAT_MODEL = os.getenv("CHITCHAT_MODEL")


@st.cache_resource
def load_models(repo_id, intent_model_name, chitchat_model_name):
    """Load all models once per server process and reuse them across reruns.

    Streamlit re-executes this whole script on every user interaction, so
    without caching every chat message reloaded all three models.

    Returns:
        Tuple ``(hr_llm, intent_tokenizer, intent_model, chitchat_pipeline)``.
    """
    hr_llm = Llama.from_pretrained(repo_id=repo_id, filename=HR_MODEL_FILENAME)
    intent_tok = AutoTokenizer.from_pretrained(intent_model_name)
    intent_mdl = AutoModelForSequenceClassification.from_pretrained(intent_model_name)
    chitchat_tok = AutoTokenizer.from_pretrained(chitchat_model_name)
    chitchat_mdl = AutoModelForSeq2SeqLM.from_pretrained(chitchat_model_name)
    chitchat_pipe = pipeline(
        "text2text-generation", model=chitchat_mdl, tokenizer=chitchat_tok
    )
    return hr_llm, intent_tok, intent_mdl, chitchat_pipe


try:
    llm_hr, intent_tokenizer, intent_model, pipe_chitchat = load_models(
        REPO_ID, INTENT_MODEL, CHITCHAT_MODEL
    )
except Exception as e:
    st.error(f"Model loading failed: {e}")
    st.stop()

CHITCHAT_LABEL = "😂 Let's Chat & Chill"
# Order must match the intent classifier's output indices — TODO confirm
# against the model's id2label mapping.
intent_labels = [
    "Employee Benefits & Policies",
    "Employee Support & Self-Service",
    "Recruitment & Onboarding",
]
total_labels = [CHITCHAT_LABEL] + intent_labels

# NOTE(review): the custom stylesheet appears to be stripped from this copy of
# the file; restore the <style> block here if .user-msg / .bot-msg are styled.
st.markdown(""" """, unsafe_allow_html=True)

if "intent" not in st.session_state:
    st.session_state.intent = total_labels[0]
if "messages" not in st.session_state:
    st.session_state.messages = []

HR_PROMPT_TEMPLATE = """You are an HR Assistant at our company. Your role is to assist employees by providing accurate and concise responses regarding company policies and HR-related questions.

### Instruction:
{}

### Input:
{}

### Response:
{}"""


def select_intent(current_intent):
    """Render the sidebar intent picker, marking the active intent with ✅.

    Returns the chosen label with the ✅ marker removed.
    """
    display_labels = [
        f"✅ {label}" if label == current_intent else label for label in total_labels
    ]
    marked = f"✅ {current_intent}"
    selected = st.sidebar.radio(
        "💼 Select Your HR Intent",
        display_labels,
        index=display_labels.index(marked) if marked in display_labels else 0,
    )
    return selected.replace("✅ ", "")


def determine_intent(text) -> Optional[int]:
    """Classify *text* with the intent model.

    Returns:
        The argmax class index, or ``None`` if classification failed (an
        error is shown in the UI). Previously failures returned 0, silently
        routing the query to the first HR category.
    """
    try:
        inputs = intent_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True
        )
        with torch.no_grad():
            logits = intent_model(**inputs).logits
        return torch.argmax(logits, dim=1).item()
    except Exception as e:
        st.error(f"Intent determination failed: {e}")
        return None


def _resolve_hr_intent(prompt, selected_intent):
    """Map the classifier's index to an HR label, falling back to *selected_intent*.

    The fallback covers classifier failures and indices outside
    ``intent_labels`` (which previously raised IndexError).
    """
    index = determine_intent(prompt)
    if index is None or not 0 <= index < len(intent_labels):
        return selected_intent
    return intent_labels[index]


def _message_html(role, content):
    """Wrap a chat message in a role-specific div.

    The content is HTML-escaped, because it is user input or model output
    rendered with ``unsafe_allow_html=True``.
    """
    role_class = "user-msg" if role == "user" else "bot-msg"
    # <br> keeps multi-line replies inside the div; a blank line would
    # otherwise end the HTML block in the markdown renderer.
    body = html.escape(content or "").replace("\n", "<br>")
    return f'<div class="{role_class}">{body}</div>'


def render_chat():
    """Re-render the stored conversation history (runs on every rerun)."""
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"], avatar=msg.get("avatar")):
            st.markdown(
                _message_html(msg["role"], msg["content"]), unsafe_allow_html=True
            )


def _post_message(role, content, avatar):
    """Store a message in session history and render it immediately."""
    st.session_state.messages.append(
        {"role": role, "content": content, "avatar": avatar}
    )
    with st.chat_message(role, avatar=avatar):
        st.markdown(_message_html(role, content), unsafe_allow_html=True)


def generate_hr_response(user_prompt, intent, max_tokens=100):
    """Answer *user_prompt* with the HR model, framed as an *intent* query.

    Args:
        user_prompt: The employee's question.
        intent: HR category label inserted into the instruction.
        max_tokens: Generation limit passed to llama-cpp (default 100).

    Returns:
        The model's reply, or a fallback apology if generation failed.
    """
    instruction = f"Answer the HR-related query categorized as '{intent}'."
    # No EOS suffix: the old code appended the *chit-chat* tokenizer's EOS
    # token, an unrelated vocabulary's training-time terminator, to an
    # inference prompt.
    formatted_prompt = HR_PROMPT_TEMPLATE.format(instruction, user_prompt, "")
    try:
        response = llm_hr.create_chat_completion(
            messages=[{"role": "user", "content": formatted_prompt}],
            max_tokens=max_tokens,
        )
        return response["choices"][0]["message"]["content"] or ""
    except Exception as e:
        st.error(f"Failed to generate HR response: {e}")
        return "⚠️ Sorry, I couldn't process your HR request at the moment."


def _chitchat_reply(prompt):
    """Generate a casual reply with the seq2seq pipeline."""
    with st.spinner("😄 Chitchatting..."):
        try:
            result = pipe_chitchat(prompt)
            return result[0]["generated_text"].strip()
        except Exception as e:
            st.error(f"Chitchat generation failed: {e}")
            return "Oops, I couldn't come up with a witty reply! 😅"


def _hr_reply(prompt):
    """Route *prompt* to the detected HR category and generate an answer."""
    predicted_label = _resolve_hr_intent(prompt, st.session_state.intent)
    if predicted_label != st.session_state.intent:
        st.warning(
            f"🔁 Automatically switched to **{predicted_label}** based on your query."
        )
        # Actually switch, so the sidebar reflects the new intent on the next
        # rerun (the warning previously claimed a switch that never happened).
        st.session_state.intent = predicted_label
    with st.spinner("Generating response with specialized HR model..."):
        return generate_hr_response(prompt, predicted_label)


# NOTE(review): the original title/subtitle HTML markup is not present in this
# copy, so native Streamlit elements are used instead.
st.title("AI HR Chatbot")
st.caption("Your personal assistant for HR queries and support.")

st.session_state.intent = select_intent(st.session_state.intent)
render_chat()

try:
    if prompt := st.chat_input("Ask me anything related to HR or just chat casually..."):
        _post_message("user", prompt, "man.png")
        if st.session_state.intent == CHITCHAT_LABEL:
            reply = _chitchat_reply(prompt)
        else:
            reply = _hr_reply(prompt)
        _post_message("bot", reply, "robot.png")
except Exception as e:
    st.error(f"Something went wrong while processing your input: {e}")