import os
import re
import requests
import streamlit as st
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
import scraping_utils
# Initialize Streamlit page config
st.set_page_config(
    page_title="SHL Assessment Recommender",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS for dark theme and improved contrast
# (minimal assumed styles; the .service-* classes are the ones the status
# badges below rely on)
st.markdown("""
<style>
.status-badges { position: fixed; top: 4rem; right: 1rem; z-index: 1000; text-align: right; }
.status-badges span { display: block; }
.service-up   { color: #22c55e; font-weight: 600; }
.service-auth { color: #eab308; font-weight: 600; }
.service-down { color: #ef4444; font-weight: 600; }
.badge-label  { color: #9ca3af; font-size: 0.8rem; }
</style>
""", unsafe_allow_html=True)
# -----------------------
# Service check functions
# -----------------------
def check_huggingface(hf_token: str, timeout: float = 4.0):
    """
    Check HuggingFace Inference availability by querying the model metadata endpoint.
    Returns tuple (status_str, is_reachable_bool).
    status_str is one of: "up", "down", "auth_error", "no_token".
    """
    if not hf_token:
        return ("no_token", False)
    url = "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        resp = requests.get(url, headers=headers, timeout=timeout)
        if resp.status_code == 200:
            return ("up", True)
        elif resp.status_code in (401, 403):
            # Service reachable, but the token was rejected
            return ("auth_error", True)
        else:
            return ("down", False)
    except requests.RequestException:
        return ("down", False)
def check_gemini(gemini_api_key: str, timeout: float = 4.0):
    """
    Check Gemini (Google Generative Language API) availability by calling the models list endpoint.
    Returns tuple (status_str, is_reachable_bool).
    status_str is one of: "up", "down", "auth_error", "no_key".
    NOTE: this uses the public generativelanguage.googleapis.com endpoint and passes the API key as the 'key' query param.
    """
    if not gemini_api_key:
        return ("no_key", False)
    url = "https://generativelanguage.googleapis.com/v1/models"
    try:
        resp = requests.get(url, params={"key": gemini_api_key}, timeout=timeout)
        if resp.status_code == 200:
            return ("up", True)
        elif resp.status_code in (401, 403):
            return ("auth_error", True)
        else:
            return ("down", False)
    except requests.RequestException:
        return ("down", False)
# Run checks on load (will run each time the app reloads)
hf_token = None
gemini_key = None
try:
    hf_token = st.secrets.get("HF_TOKEN")
except Exception:
    hf_token = None
try:
    gemini_key = st.secrets.get("GEMINI_API_KEY")
except Exception:
    gemini_key = None
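
# st.secrets reads from .streamlit/secrets.toml (or the deployment's secret
# store); the keys expected here, for example:
#   HF_TOKEN = "hf_..."
#   GEMINI_API_KEY = "..."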
hf_status, hf_reachable = check_huggingface(hf_token)
gemini_status, gemini_reachable = check_gemini(gemini_key)
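
# These checks re-run on every Streamlit rerun. If that proves too chatty, a
# short TTL cache is one option (sketch, assuming a 60-second window is fine):
#
#   @st.cache_data(ttl=60)
#   def cached_checks(token, key):
#       return check_huggingface(token), check_gemini(key)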
# Map status to CSS class + friendly text
def status_display_info(status_tuple):
    status, reachable = status_tuple
    if status == "up":
        return ("service-up", "UP")
    if status == "auth_error":
        return ("service-auth", "AUTH")
    if status in ("no_token", "no_key"):
        return ("service-auth", "MISSING")
    return ("service-down", "DOWN")
hf_class, hf_text = status_display_info((hf_status, hf_reachable))
gem_class, gem_text = status_display_info((gemini_status, gemini_reachable))
# Inject the side badges HTML
st.markdown(f"""
HF ยท {hf_text}
GEM ยท {gem_text}
Service Status
""", unsafe_allow_html=True)
# Also show clear, accessible textual statuses in the sidebar (useful for screen readers / logs)
with st.sidebar:
    st.markdown("### API Health")
    # HuggingFace
    if hf_status == "up":
        st.success("HuggingFace Inference: reachable (200 OK)")
    elif hf_status == "auth_error":
        st.warning("HuggingFace Inference: reachable but auth failed (401/403). Check HF_TOKEN.")
    elif hf_status == "no_token":
        st.error("HuggingFace Inference: HF_TOKEN missing in st.secrets.")
    else:
        st.error("HuggingFace Inference: not reachable (network or service error).")
    # Gemini
    if gemini_status == "up":
        st.success("Gemini (Google Generative API): reachable (200 OK)")
    elif gemini_status == "auth_error":
        st.warning("Gemini (Google Generative API): reachable but auth failed (401/403). Check GEMINI_API_KEY.")
    elif gemini_status == "no_key":
        st.error("Gemini (Google Generative API): GEMINI_API_KEY missing in st.secrets.")
    else:
        st.error("Gemini (Google Generative API): not reachable (network or service error).")
    st.markdown("---")
    st.caption("Checks run on each page reload. Visual badges on the right show quick status.")
# -----------------------
# The rest of your code (unchanged apart from imports & checks above)
# -----------------------
# Initialize embeddings backed by the HF Inference API
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=st.secrets["HF_TOKEN"],
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
def load_vector_store():
    """Load FAISS vector store with security checks"""
    if not os.path.exists("faiss_index"):
        st.error("Vector store not found. Please process documents first.")
        return None
    try:
        return FAISS.load_local(
            "faiss_index",
            embeddings,
            allow_dangerous_deserialization=True  # Only for trusted sources
        )
    except Exception as e:
        st.error(f"Error loading vector store: {str(e)}")
        return None
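
# The "faiss_index" directory is expected to be built offline from the scraped
# SHL catalog, e.g. (sketch, assuming `chunks` is a list of text snippets):
#   FAISS.from_texts(chunks, embeddings).save_local("faiss_index")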
def get_conversational_chain():
    """Create QA chain with proper model configuration"""
    with open("system_prompt.md", "r") as f:
        SYSTEM_PROMPT = f.read()

    prompt_template = f"""
{SYSTEM_PROMPT}
Context:\n{{context}}\n
Query:\n{{question}}\n
Response:
"""

    # Updated model name to the new google-genai supported model
    model = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash-001",
        google_api_key=st.secrets["GEMINI_API_KEY"],
        temperature=0.2,
        top_k=20,
        top_p=0.95,
        verbose=True
    )

    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"]
    )

    return load_qa_chain(
        model,
        chain_type="stuff",
        prompt=prompt,
        verbose=False
    )
def process_query(query):
    """Process user query with URL scraping and RAG"""
    # URL detection and scraping
    urls = re.findall(r'(https?://\S+)', query)
    scraped_data = ""
    if urls:
        with st.status("🔗 Scraping linked content...", expanded=True):
            for url in urls:
                scraped = scraping_utils.scrape_url(url)
                scraped_data += f"\n\nScraped content from {url}:\n{scraped}"
    full_query = query + scraped_data

    with st.spinner("🔎 Analyzing request with SHL knowledge base..."):
        try:
            db = load_vector_store()
            if db is None:
                return "Error: Knowledge base not loaded"
            docs = db.similarity_search(full_query, k=10)
            chain = get_conversational_chain()
            response = chain.invoke(
                {"input_documents": docs, "question": full_query},
                return_only_outputs=True
            )
            return response["output_text"]
        except Exception as e:
            st.error(f"Analysis error: {str(e)}")
            return None
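
# The system prompt is assumed to make the model wrap each section in
# XML-like tags, e.g.:
#   <result>...recommendations...</result><explanation>...</explanation>
# render_response() splits those sections into tabs below.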
def render_response(response):
    """Render AI response with beautiful markdown formatting"""
    if not response:
        return
    # Extract sections between XML-like tags
    sections = re.findall(r'<(\w+)>([\s\S]*?)</\1>', response)
    if not sections:
        st.markdown(response)
        return
    # Reorder sections: move "result" to be the first tab if present
    sections = sorted(sections, key=lambda x: 0 if x[0].lower() == "result" else 1)
    # Create tabs for each section
    tab_names = [sec[0].capitalize() for sec in sections]
    tabs = st.tabs(tab_names)
    for i, tab in enumerate(tabs):
        with tab:
            content = sections[i][1].strip()
            if sections[i][0].lower() == "result":
                st.markdown(content, unsafe_allow_html=True)
            else:
                st.markdown(f"""
{content}
""", unsafe_allow_html=True)
# Main application UI
st.title("SHL Assessment Recommendation System")
st.markdown("""
AI-Powered Assessment Matching Engine
Describe your assessment needs and get personalized recommendations from SHL's product catalog.
""", unsafe_allow_html=True)
query = st.text_area(
    "Describe your assessment needs:",
    placeholder="e.g. 'I need cognitive ability tests under 45 minutes for remote hiring of financial analysts...'",
    height=150,
    key="query_input"
)
if st.button("Generate Recommendations", type="primary"):
    if not query:
        st.warning("Please enter your assessment requirements")
    else:
        response = process_query(query)
        if response:
            render_response(response)