GS123's picture
Upload 2 files
8964653 verified
import streamlit as st
import pandas as pd
import chromadb
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import math
import re
# st.title("hello")
# Render the app using the wide page layout.
st.set_page_config(layout="wide")

# --- Configuration ---
# CSV file the assessment rows are read from when the collection is empty.
CSV_FILE = "shl_data.csv"
# Name under which the assessments are stored in ChromaDB.
COLLECTION_NAME = "shl_assessments"
# Sentence-Transformers model used for semantic search;
# 'all-MiniLM-L6-v2' is an alternative.
MODEL_NAME = "msmarco-distilbert-base-v4"
# --- Caching Functions ---
# Cache the embedding model loading
@st.cache_resource
def load_embedding_model(model_name=MODEL_NAME):
    """Instantiate the Sentence Transformer identified by ``model_name``.

    The function is wrapped in ``st.cache_resource``. Any failure while
    constructing the model is shown to the user through ``st.error``.

    Args:
        model_name: Identifier handed to ``SentenceTransformer``.

    Returns:
        The constructed model, or ``None`` when construction raised.
    """
    print("Loading embedding model...")
    try:
        embedder = SentenceTransformer(model_name)
        print("Embedding model loaded.")
        return embedder
    except Exception as exc:
        st.error(f"Error loading embedding model '{model_name}': {exc}")
        return None
# Cache the ChromaDB client and collection setup
# Display names for the single-letter test-type codes used in the CSV.
# Codes that are not listed here are kept unchanged.
_TEST_TYPE_NAMES = {
    'A': 'Ability', 'B': 'Behavior', 'C': 'Cognitive', 'P': 'Personality',
    'S': 'Simulation', 'K': 'Knowledge & Skills', 'D': 'Development', 'E': 'Exercise',
}


def _expand_test_types(raw_value):
    """Split a comma-separated test-type string into unique display names."""
    codes = (part.strip() for part in raw_value.split(','))
    # dict.fromkeys removes duplicates while keeping first-seen order, so the
    # stored list is deterministic (a set gives no ordering guarantee).
    return list(dict.fromkeys(_TEST_TYPE_NAMES.get(code, code) for code in codes if code))


def _clean_assessment_frame(df):
    """Normalise the raw CSV frame into the columns stored in the collection.

    Args:
        df: DataFrame as returned by ``pd.read_csv`` for the data file.

    Returns:
        A new DataFrame with ``url``, ``name``, ``description``,
        ``remote_support``, ``adaptive_support``, ``duration`` and
        ``test_type_list`` columns, restricted to rows whose URL starts with
        ``http``.

    Raises:
        KeyError: If the frame has no ``Link``/``url`` column.
    """
    df = df.rename(columns={
        'Link': 'url', 'Assessment Name': 'name', 'Remote Testing': 'remote_support',
        'Adaptive/IRT': 'adaptive_support', 'Assessment Length': 'duration',
        'Test Type': 'test_type_raw', 'Description': 'description',
    })

    # Assign the filled column back instead of calling fillna(inplace=True) on
    # a column selection: the chained in-place form is deprecated and has no
    # effect under pandas Copy-on-Write.
    text_defaults = (
        ('description', 'No description available.'),
        ('name', 'Unnamed Assessment'),
    )
    for column, default in text_defaults:
        if column in df.columns:
            df[column] = df[column].fillna(default)
        else:
            df[column] = default

    # Anything other than a case-insensitive "yes" counts as "No".
    for column in ('remote_support', 'adaptive_support'):
        if column in df.columns:
            is_yes = df[column].astype(str).str.strip().str.lower() == 'yes'
            df[column] = is_yes.map({True: 'Yes', False: 'No'})
        else:
            df[column] = 'No'

    # Unparseable or missing durations become 0.
    if 'duration' in df.columns:
        df['duration'] = pd.to_numeric(df['duration'], errors='coerce').fillna(0).astype(int)
    else:
        df['duration'] = 0

    if 'test_type_raw' in df.columns:
        df['test_type_list'] = df['test_type_raw'].fillna('').astype(str).apply(_expand_test_types)
    else:
        df['test_type_list'] = [[] for _ in range(len(df))]

    df = df.dropna(subset=['url', 'name'])
    # astype(str) keeps the .str accessor usable even if the column was not
    # parsed as text.
    return df[df['url'].astype(str).str.startswith('http')]


def _build_chroma_records(df):
    """Convert cleaned rows into parallel ``(ids, documents, metadatas)`` lists."""
    ids, documents, metadatas = [], [], []
    for index, row in df.iterrows():
        name = str(row['name'])
        description = str(row['description'])
        # The embedded document is "name: description" with whitespace collapsed.
        documents.append(re.sub(r'\s+', ' ', f"{name}: {description}").strip())

        test_types = row['test_type_list']
        metadatas.append({
            'url': str(row['url']),
            'adaptive_support': str(row['adaptive_support']),
            'description': description,
            'duration': int(row['duration']),
            'remote_support': str(row['remote_support']),
            'name': name,
            # The list is stored as a JSON string and decoded again at query time.
            'test_type_json': json.dumps(test_types if isinstance(test_types, list) else []),
        })
        ids.append(f"shl_assessment_{index}")  # IDs must be strings
    return ids, documents, metadatas


# Cache the ChromaDB client and collection setup
@st.cache_resource
def setup_chroma_collection(collection_name=COLLECTION_NAME, model_name=MODEL_NAME):
    """Create the ChromaDB collection and load the CSV data into it if empty.

    An in-memory client is used, and documents are embedded by Chroma through
    a ``SentenceTransformerEmbeddingFunction`` built from ``model_name``.

    Args:
        collection_name: Name of the collection to get or create.
        model_name: Sentence-Transformers model used for embedding.

    Returns:
        The collection (possibly empty when the CSV has no valid rows), or
        ``None`` when the CSV cannot be read or setup fails. Failures are
        reported through ``st.error``.

    NOTE(review): ``st.cache_resource`` presumably caches a ``None`` result as
    well, so a failed setup may persist until the cache is cleared - confirm.
    """
    print("Setting up ChromaDB collection...")
    try:
        # Import the submodule explicitly rather than relying on
        # ``import chromadb`` exposing ``chromadb.utils.embedding_functions``.
        from chromadb.utils import embedding_functions

        # Using an in-memory client suitable for Streamlit sharing / HF Spaces
        client = chromadb.Client()
        embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=model_name
        )
        collection = client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_function,
        )
        print(f"ChromaDB collection '{collection_name}' retrieved/created.")

        # Load and preprocess data only if collection is empty
        if collection.count() == 0:
            print("Collection is empty. Loading data from CSV...")
            try:
                raw_df = pd.read_csv(CSV_FILE)
            except FileNotFoundError:
                st.error(f"Error: Data file '{CSV_FILE}' not found. Make sure it's in the same directory as app.py.")
                return None
            except Exception as e:
                st.error(f"Error reading CSV file: {e}")
                return None

            ids, documents, metadatas = _build_chroma_records(_clean_assessment_frame(raw_df))
            if not ids:
                st.warning("No valid data found in the CSV to add to the database.")
                return collection  # Return empty collection

            print(f"Adding {len(ids)} items to the collection...")
            batch_size = 100
            for start in range(0, len(ids), batch_size):
                stop = start + batch_size
                collection.add(
                    ids=ids[start:stop],
                    documents=documents[start:stop],
                    metadatas=metadatas[start:stop],
                )
            print("Data added successfully.")

        print(f"ChromaDB setup complete. Collection size: {collection.count()}")
        return collection
    except Exception as e:
        st.error(f"Error setting up ChromaDB: {e}")
        print(f"!!! Error setting up ChromaDB: {e}")  # Also print to console
        return None
# --- Query Function ---
def _parse_test_types(raw_json, item_id=None):
    """Decode the JSON-encoded test-type list from metadata; ``[]`` on any problem."""
    try:
        parsed = json.loads(raw_json)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError; TypeError covers non-string
        # values such as None.
        print(f"Warning: Could not parse test_type_json for ID {item_id}: {raw_json}")
        return []
    return parsed if isinstance(parsed, list) else []


def _safe_int(value, default=0):
    """Convert ``value`` to int, returning ``default`` when it is not convertible."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _format_assessment(meta, item_id=None):
    """Build one API-spec result dict from a stored metadata mapping."""
    meta = meta or {}
    return {
        "url": meta.get('url', ''),
        "adaptive_support": meta.get('adaptive_support', 'No'),
        "description": meta.get('description', 'No description available.'),
        "duration": _safe_int(meta.get('duration', 0)),
        "remote_support": meta.get('remote_support', 'No'),
        "test_type": _parse_test_types(meta.get('test_type_json', '[]'), item_id),
        # Include name for display purposes in Streamlit
        "name": meta.get('name', 'Unknown Assessment'),
    }


def get_recommendations_from_chroma(query_text, collection, top_n=10):
    """Query the collection and format the best matches for the API spec.

    Args:
        query_text: Natural-language query or job description.
        collection: Object providing ``count()``, ``query()`` and ``peek()``
            (the ChromaDB collection), or ``None``.
        top_n: Maximum number of recommendations to return.

    Returns:
        ``{"recommended_assessments": [...]}`` where each entry has the keys
        ``url``, ``adaptive_support``, ``description``, ``duration``,
        ``remote_support``, ``test_type`` and ``name``. The list is empty when
        the collection is missing or empty, ``top_n`` is below 1, or the query
        fails. When the query yields no usable result, the first item returned
        by ``collection.peek`` is used as a fallback.
    """
    if collection is None:
        print("Collection is not available or empty.")
        return {"recommended_assessments": []}
    total = collection.count()
    if total == 0:
        print("Collection is not available or empty.")
        return {"recommended_assessments": []}
    if top_n < 1:
        # Nothing was requested; skip the query, which would fail with n_results=0.
        return {"recommended_assessments": []}

    try:
        results = collection.query(
            query_texts=[query_text],
            # Retrieve more initially so duplicates can be dropped
            n_results=min(top_n * 2, total),
            include=['metadatas', 'distances'],
        )
    except Exception as e:
        st.error(f"Error querying ChromaDB: {e}")
        print(f"!!! Error querying ChromaDB: {e}")
        return {"recommended_assessments": []}

    recommendations = []
    seen_urls = set()  # Avoid duplicates if any slipped through
    # Results are nested one list per query text; only one query was sent.
    id_rows = (results.get('ids') if results else None) or [[]]
    meta_rows = (results.get('metadatas') if results else None) or [[]]
    for item_id, meta in zip(id_rows[0], meta_rows[0]):
        if len(recommendations) >= top_n:  # Stop once we have enough
            break
        meta = meta or {}
        url = meta.get('url', '')
        if not url or url in seen_urls:
            continue
        seen_urls.add(url)
        recommendations.append(_format_assessment(meta, item_id))

    # Ensure minimum 1 result if possible
    if not recommendations:
        print("Query returned no results, attempting fallback peek...")
        try:
            fallback = collection.peek(limit=1)  # Get the 'first' item
            if fallback and fallback.get('ids'):
                recommendations.append(
                    _format_assessment(fallback['metadatas'][0], fallback['ids'][0])
                )
        except Exception as fb_e:
            print(f"Error during fallback peek: {fb_e}")

    return {"recommended_assessments": recommendations}
# --- Streamlit App UI ---
st.title("🚀 SHL Assessment Recommendation System")
st.markdown("Enter a natural language query or job description text to find relevant SHL assessments.")

# Build the collection (cached). The embedding model is used implicitly by
# Chroma's embedding function, so it is not loaded separately here.
collection = setup_chroma_collection()

# User Input
query = st.text_area("Enter your query or job description:", height=150)

# Search Button
search_button = st.button("Find Assessments")

# Whitespace-only input is treated as "no query" rather than being searched.
query_text = query.strip() if query else ""

if search_button and not query_text:
    st.warning("Please enter a query.")
elif search_button and collection is None:
    st.error("Database collection could not be loaded. Please check logs.")
elif search_button:
    with st.spinner("Searching for relevant assessments..."):
        results_data = get_recommendations_from_chroma(query_text, collection, top_n=10)
        recommendations = results_data.get("recommended_assessments", [])

    if not recommendations:
        # Skip the "Top N" header when there is nothing to list.
        st.warning("No relevant assessments found for your query.")
    else:
        st.subheader(f"Top {len(recommendations)} Recommendations:")
        for position, rec in enumerate(recommendations, start=1):
            st.markdown("---")
            st.markdown(f"**{position}. {rec.get('name', 'N/A')}**")
            st.markdown(f"**URL:** [{rec.get('url')}]({rec.get('url')})")
            st.markdown(f"**Description:** {rec.get('description')}")
            col1, col2, col3 = st.columns(3)
            with col1:
                st.markdown(f"**Duration:** {rec.get('duration', 'N/A')} min")
            with col2:
                st.markdown(f"**Remote Support:** {rec.get('remote_support', 'N/A')}")
            with col3:
                st.markdown(f"**Adaptive/IRT:** {rec.get('adaptive_support', 'N/A')}")
            # Display test types as a comma-separated string; str() guards
            # against non-string entries in the decoded list.
            test_types_str = ", ".join(str(t) for t in rec.get('test_type', []))
            st.markdown(f"**Test Type(s):** {test_types_str if test_types_str else 'N/A'}")