Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| import json | |
| import math | |
| import re | |
| # st.title("hello") | |
| st.set_page_config(layout="wide") | |
| # --- Configuration --- | |
| CSV_FILE = "shl_data.csv" | |
| COLLECTION_NAME = "shl_assessments" | |
| # Use a robust model good for semantic search | |
| MODEL_NAME = 'msmarco-distilbert-base-v4' # Or 'all-MiniLM-L6-v2' | |
| # --- Caching Functions --- | |
| # Cache the embedding model loading | |
| def load_embedding_model(model_name=MODEL_NAME): | |
| """Loads the Sentence Transformer model.""" | |
| print("Loading embedding model...") | |
| try: | |
| model = SentenceTransformer(model_name) | |
| print("Embedding model loaded.") | |
| return model | |
| except Exception as e: | |
| st.error(f"Error loading embedding model '{model_name}': {e}") | |
| return None | |
| # Cache the ChromaDB client and collection setup | |
| def setup_chroma_collection(collection_name=COLLECTION_NAME, model_name=MODEL_NAME): | |
| """Initializes ChromaDB client and collection, loading data if empty.""" | |
| print("Setting up ChromaDB collection...") | |
| try: | |
| # Using an in-memory client suitable for Streamlit sharing / HF Spaces | |
| client = chromadb.Client() | |
| # Use the SentenceTransformerEmbeddingFunction for automatic embedding | |
| embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name) | |
| collection = client.get_or_create_collection( | |
| name=collection_name, | |
| embedding_function=embedding_function | |
| # metadata={"hnsw:space": "cosine"} # Optional: ensure cosine distance | |
| ) | |
| print(f"ChromaDB collection '{collection_name}' retrieved/created.") | |
| # Load and preprocess data only if collection is empty | |
| if collection.count() == 0: | |
| print("Collection is empty. Loading data from CSV...") | |
| try: | |
| df = pd.read_csv(CSV_FILE) | |
| except FileNotFoundError: | |
| st.error(f"Error: Data file '{CSV_FILE}' not found. Make sure it's in the same directory as app.py.") | |
| return None | |
| except Exception as e: | |
| st.error(f"Error reading CSV file: {e}") | |
| return None | |
| # --- Data Cleaning and Preprocessing (same as Colab) --- | |
| df.rename(columns={ | |
| 'Link': 'url', 'Assessment Name': 'name', 'Remote Testing': 'remote_support', | |
| 'Adaptive/IRT': 'adaptive_support', 'Assessment Length': 'duration', | |
| 'Test Type': 'test_type_raw', 'Description': 'description' | |
| }, inplace=True) | |
| df['description'].fillna('No description available.', inplace=True) | |
| df['name'].fillna('Unnamed Assessment', inplace=True) | |
| for col in ['remote_support', 'adaptive_support']: | |
| if col in df.columns: | |
| df[col] = df[col].astype(str).str.strip().str.lower().apply(lambda x: 'Yes' if x == 'yes' else 'No') | |
| else: df[col] = 'No' | |
| if 'duration' in df.columns: | |
| df['duration'] = pd.to_numeric(df['duration'], errors='coerce').fillna(0).astype(int) | |
| else: df['duration'] = 0 | |
| if 'test_type_raw' in df.columns: | |
| df['test_type_list'] = df['test_type_raw'].fillna('').astype(str).apply( | |
| lambda x: [t.strip() for t in x.split(',') if t.strip()] | |
| ) | |
| type_mapping = { | |
| 'A': 'Ability', 'B': 'Behavior', 'C': 'Cognitive', 'P': 'Personality', | |
| 'S': 'Simulation', 'K': 'Knowledge & Skills', 'D': 'Development', 'E': 'Exercise' | |
| } | |
| df['test_type_list'] = df['test_type_list'].apply(lambda types: list(set([type_mapping.get(t, t) for t in types]))) | |
| else: df['test_type_list'] = [[] for _ in range(len(df))] | |
| df.dropna(subset=['url', 'name'], inplace=True) | |
| df = df[df['url'].str.startswith('http')] | |
| # ------------------------------------------------------- | |
| # --- Prepare for ChromaDB --- | |
| documents = [] | |
| metadatas = [] | |
| ids = [] | |
| required_fields_for_api = ['url', 'adaptive_support', 'description', 'duration', 'remote_support'] | |
| for index, row in df.iterrows(): | |
| doc_text = f"{row['name']}: {row['description']}" | |
| documents.append(re.sub(r'\s+', ' ', doc_text).strip()) | |
| meta = {field: row[field] for field in required_fields_for_api if field in row} | |
| meta['url'] = str(meta.get('url', '')) | |
| meta['adaptive_support'] = str(meta.get('adaptive_support', 'No')) | |
| meta['description'] = str(meta.get('description', 'No description available.')) | |
| meta['duration'] = int(meta.get('duration', 0)) | |
| meta['remote_support'] = str(meta.get('remote_support', 'No')) | |
| meta['name'] = str(row['name']) | |
| test_type_list = row['test_type_list'] if 'test_type_list' in row and isinstance(row['test_type_list'], list) else [] | |
| meta['test_type_json'] = json.dumps(test_type_list) # Store as JSON string | |
| metadatas.append(meta) | |
| ids.append(f"shl_assessment_{index}") # Make sure IDs are strings | |
| # -------------------------- | |
| if not ids: | |
| st.warning("No valid data found in the CSV to add to the database.") | |
| return collection # Return empty collection | |
| print(f"Adding {len(ids)} items to the collection...") | |
| # Add data in batches if necessary (though for this size, one go is fine) | |
| batch_size = 100 | |
| for i in range(0, len(ids), batch_size): | |
| collection.add( | |
| ids=ids[i:i+batch_size], | |
| documents=documents[i:i+batch_size], | |
| metadatas=metadatas[i:i+batch_size] | |
| ) | |
| print("Data added successfully.") | |
| print(f"ChromaDB setup complete. Collection size: {collection.count()}") | |
| return collection | |
| except Exception as e: | |
| st.error(f"Error setting up ChromaDB: {e}") | |
| print(f"!!! Error setting up ChromaDB: {e}") # Also print to console | |
| return None | |
| # --- Query Function --- | |
| def get_recommendations_from_chroma(query_text, collection, top_n=10): | |
| """Queries the ChromaDB collection and formats results for API spec.""" | |
| if collection is None or collection.count() == 0: | |
| print("Collection is not available or empty.") | |
| return {"recommended_assessments": []} | |
| try: | |
| results = collection.query( | |
| query_texts=[query_text], | |
| n_results=min(top_n * 2, collection.count()), # Retrieve more initially for potential filtering | |
| include=['metadatas', 'distances'] | |
| ) | |
| except Exception as e: | |
| st.error(f"Error querying ChromaDB: {e}") | |
| print(f"!!! Error querying ChromaDB: {e}") | |
| return {"recommended_assessments": []} | |
| recommended_assessments = [] | |
| seen_urls = set() # Avoid duplicates if any slipped through | |
| if results and results.get('ids') and results['ids'][0]: | |
| for i, item_id in enumerate(results['ids'][0]): | |
| if len(recommended_assessments) >= top_n: # Stop once we have enough | |
| break | |
| meta = results['metadatas'][0][i] | |
| # distance = results['distances'][0][i] # Lower distance = more similar | |
| # Basic check for duplicate URLs | |
| url = meta.get('url', '') | |
| if not url or url in seen_urls: | |
| continue | |
| seen_urls.add(url) | |
| # Parse test_type from JSON string | |
| test_type_list = [] | |
| test_type_json_str = meta.get('test_type_json', '[]') | |
| try: | |
| test_type_list = json.loads(test_type_json_str) | |
| if not isinstance(test_type_list, list): test_type_list = [] | |
| except json.JSONDecodeError: | |
| print(f"Warning: Could not parse test_type_json for ID {item_id}: {test_type_json_str}") | |
| test_type_list = [] | |
| # Format according to API spec | |
| formatted_result = { | |
| "url": url, | |
| "adaptive_support": meta.get('adaptive_support', 'No'), | |
| "description": meta.get('description', 'No description available.'), | |
| "duration": int(meta.get('duration', 0)), | |
| "remote_support": meta.get('remote_support', 'No'), | |
| "test_type": test_type_list, | |
| # Include name for display purposes in Streamlit | |
| "name": meta.get('name', 'Unknown Assessment') | |
| } | |
| recommended_assessments.append(formatted_result) | |
| # Ensure minimum 1 result if possible (and max 10) | |
| final_recommendations = recommended_assessments[:top_n] | |
| if not final_recommendations and collection.count() > 0: | |
| print("Query returned no results, attempting fallback peek...") | |
| try: | |
| fallback_results = collection.peek(limit=1) # Get the 'first' item | |
| if fallback_results and fallback_results.get('ids'): | |
| meta = fallback_results['metadatas'][0] | |
| test_type_list_fb = [] | |
| test_type_json_str_fb = meta.get('test_type_json', '[]') | |
| try: test_type_list_fb = json.loads(test_type_json_str_fb) | |
| except: pass | |
| final_recommendations.append({ | |
| "url": meta.get('url', ''), | |
| "adaptive_support": meta.get('adaptive_support', 'No'), | |
| "description": meta.get('description', 'No description available.'), | |
| "duration": int(meta.get('duration', 0)), | |
| "remote_support": meta.get('remote_support', 'No'), | |
| "test_type": test_type_list_fb if isinstance(test_type_list_fb, list) else [], | |
| "name": meta.get('name', 'Unknown Assessment') | |
| }) | |
| except Exception as fb_e: | |
| print(f"Error during fallback peek: {fb_e}") | |
| return {"recommended_assessments": final_recommendations} | |
| # --- Streamlit App UI --- | |
| st.title("🚀 SHL Assessment Recommendation System") | |
| st.markdown("Enter a natural language query or job description text to find relevant SHL assessments.") | |
| # Load model and collection (cached) | |
| # model = load_embedding_model() # Model is implicitly used by Chroma's embedding function | |
| collection = setup_chroma_collection() | |
| # User Input | |
| query = st.text_area("Enter your query or job description:", height=150) | |
| # Search Button | |
| search_button = st.button("Find Assessments") | |
| if search_button and query: | |
| if collection is not None: | |
| with st.spinner("Searching for relevant assessments..."): | |
| results_data = get_recommendations_from_chroma(query, collection, top_n=10) | |
| recommendations = results_data.get("recommended_assessments", []) | |
| st.subheader(f"Top {len(recommendations)} Recommendations:") | |
| if recommendations: | |
| for i, rec in enumerate(recommendations): | |
| st.markdown(f"---") | |
| st.markdown(f"**{i+1}. {rec.get('name', 'N/A')}**") | |
| st.markdown(f"**URL:** [{rec.get('url')}]({rec.get('url')})") | |
| st.markdown(f"**Description:** {rec.get('description')}") | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| st.markdown(f"**Duration:** {rec.get('duration', 'N/A')} min") | |
| with col2: | |
| st.markdown(f"**Remote Support:** {rec.get('remote_support', 'N/A')}") | |
| with col3: | |
| st.markdown(f"**Adaptive/IRT:** {rec.get('adaptive_support', 'N/A')}") | |
| # Display test types as a comma-separated string | |
| test_types_str = ", ".join(rec.get('test_type', [])) | |
| st.markdown(f"**Test Type(s):** {test_types_str if test_types_str else 'N/A'}") | |
| else: | |
| st.warning("No relevant assessments found for your query.") | |
| else: | |
| st.error("Database collection could not be loaded. Please check logs.") | |
| elif search_button and not query: | |
| st.warning("Please enter a query.") |