"""Streamlit front-end for the AI Smart Search product catalogue.

Offers a search bar (text, image, or audio) that queries the search API
configured below and shows the matching products in a grid, with a
per-product details view. Before any search, random products are shown.
"""

import streamlit as st
import requests
from pathlib import Path

# ---------------------- CONFIG ----------------------
API_BASE_URL = "https://hamdy005-raij-ai.hf.space"
API_URL_TEXT = f"{API_BASE_URL}/search/text"
API_URL_IMAGE = f"{API_BASE_URL}/search/image"
API_URL_AUDIO = f"{API_BASE_URL}/search/audio"
API_URL_PRODUCT = f"{API_BASE_URL}/product"
API_URL_RANDOM = f"{API_BASE_URL}/products/random"

# Search-type selector labels (also used to pick the input widgets).
SEARCH_TEXT = "📝 Text"
SEARCH_IMAGE = "🖼️ Image"
SEARCH_AUDIO = "🎤 Audio"

# Product cards per grid row.
GRID_COLUMNS = 4
# Card titles longer than this are truncated with "...".
MAX_CARD_TITLE_LENGTH = 40

# ---------------------- PAGE CONFIG ----------------------
st.set_page_config(
    page_title="AI Smart Search",
    page_icon="🔎",
    layout="wide",
    initial_sidebar_state="collapsed",
)


# ---------------------- LOAD CSS ----------------------
def load_css():
    """Inject ``styles.css`` from this script's directory, if it exists."""
    css_file = Path(__file__).parent / "styles.css"
    if css_file.exists():
        css = css_file.read_text(encoding="utf-8")
        # A raw <style> tag only reaches the page with unsafe_allow_html=True.
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


load_css()

# ---------------------- SESSION STATE ----------------------
_SESSION_DEFAULTS = {
    # Suffix for the audio widget keys; bumping it gives fresh, empty widgets.
    "audio_key": 0,
    "search_results": None,
    "prediction_info": None,
    "selected_product": None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default


# ---------------------- HELPER FUNCTIONS ----------------------
def build_products_list(data):
    """Build a list of product dicts from a search API response.

    ``product_ids`` and ``titles`` are parallel lists defining the products
    (iteration stops at the shorter one). ``image_urls`` and ``prices`` are
    optional parallel lists; missing or short lists yield ``None`` entries.
    """
    image_urls = data.get("image_urls") or []
    prices = data.get("prices") or []
    id_title_pairs = zip(data.get("product_ids") or [], data.get("titles") or [])
    products = []
    for idx, (pid, title) in enumerate(id_title_pairs):
        products.append({
            "id": pid,
            "title": title,
            "image_url": image_urls[idx] if idx < len(image_urls) else None,
            "price": prices[idx] if idx < len(prices) else None,
        })
    return products


def _format_price(value):
    """Format *value* as a Markdown-safe dollar amount (e.g. ``\\$19.99``).

    The dollar sign is escaped because st.markdown renders text between two
    ``$`` characters as inline LaTeX.
    """
    return f"\\${value:.2f}"


def _render_product_card(product, key):
    """Render one bordered product card; *key* must be unique per page."""
    with st.container(border=True):
        image_url = product.get("image_url")
        if image_url:
            try:
                st.image(image_url, width="stretch")
            except Exception:
                st.markdown("🖼️ *No image*")

        title = product.get("title") or "N/A"
        if len(title) > MAX_CARD_TITLE_LENGTH:
            title = title[:MAX_CARD_TITLE_LENGTH] + "..."
        st.markdown(f"**{title}**")

        price = product.get("price")
        if price:
            st.markdown(_format_price(price))

        if st.button("View Details", key=key, type="secondary"):
            st.session_state.selected_product = product.get("id")
            st.rerun()


def _render_product_grid(products, key_prefix):
    """Render *products* as rows of GRID_COLUMNS cards.

    Button keys combine *key_prefix*, the grid position and the product id,
    so repeated or missing ids cannot produce duplicate widget keys.
    """
    for row_start in range(0, len(products), GRID_COLUMNS):
        row = products[row_start:row_start + GRID_COLUMNS]
        # Always create a full row of columns so card widths stay uniform.
        columns = st.columns(GRID_COLUMNS)
        for offset, (column, product) in enumerate(zip(columns, row)):
            with column:
                _render_product_card(
                    product,
                    key=f"{key_prefix}_{row_start + offset}_{product.get('id')}",
                )


def display_products(products, prediction_info=None):
    """Display search results, preceded by an image/audio prediction banner."""
    if prediction_info:
        search_kind = prediction_info.get("type")
        if search_kind == "image":
            confidence = prediction_info.get("confidence") or 0
            st.info(f"🖼️ **Detected Category:** {prediction_info.get('category', 'N/A')} ({confidence:.1%} confidence)")
        elif search_kind == "audio":
            st.info(f"🎤 **Transcription:** \"{prediction_info.get('transcription', 'N/A')}\"")

    st.markdown(f"### 🛍️ {len(products)} Products Found")
    _render_product_grid(products, key_prefix="view")


def _render_product_images(images):
    """Show the main product image followed by up to four thumbnails."""
    if not images:
        st.markdown("🖼️ *No image available*")
        return
    st.image(images[0], width="stretch")
    if len(images) > 1:
        thumbnails = images[:4]
        for thumb_col, img_url in zip(st.columns(len(thumbnails)), thumbnails):
            with thumb_col:
                st.image(img_url, width=80)


def _render_product_info(product):
    """Show title, SKU, price/discount, stock, description and tags."""
    st.markdown(f"## {product.get('title', 'N/A')}")
    sku = product.get("sku")
    if sku:
        st.caption(f"SKU: {sku}")
    st.markdown("---")

    price = product.get("price")
    old_price = product.get("old_price")
    # Guard on price: comparing old_price with a missing price raises TypeError.
    if price is not None and old_price and old_price > price:
        discount = int((1 - price / old_price) * 100)
        st.markdown(f"{_format_price(old_price)} {_format_price(price)} -{discount}%")
    elif price:
        st.markdown(_format_price(price))
    else:
        st.markdown("*Price not available*")
    st.markdown("---")

    stock = product.get("stock") or 0
    if stock > 0:
        st.success(f"✅ In Stock ({stock} available)")
    else:
        st.error("❌ Out of Stock")
    st.markdown("---")

    st.markdown("### 📝 Description")
    st.write(product.get("description") or "No description available.")

    tags = product.get("tags") or []
    if tags:
        st.markdown("### 🏷️ Tags")
        st.write(" • ".join(str(tag) for tag in tags))


def display_product_details(product_id):
    """Fetch *product_id* from the product API and show its details page."""
    if st.button("← Back to Results", type="secondary"):
        st.session_state.selected_product = None
        st.rerun()

    with st.spinner("Loading product details..."):
        try:
            response = requests.get(f"{API_URL_PRODUCT}/{product_id}", timeout=30)
            response_ok = response.status_code == 200
            product = response.json() if response_ok else None
        except (requests.RequestException, ValueError) as e:
            st.error(f"Error: {e}")
            return

    if not response_ok:
        st.error(f"Error loading product: {response.status_code}")
        return
    if not isinstance(product, dict) or product.get("error"):
        st.error("Product not found")
        return

    image_col, info_col = st.columns([1, 2])
    with image_col:
        _render_product_images(product.get("images") or [])
    with info_col:
        _render_product_info(product)


def reset_search():
    """Reset search state"""
    st.session_state.search_results = None
    st.session_state.prediction_info = None
    st.session_state.selected_product = None


def fetch_random_products(limit=10):
    """Fetch up to *limit* random products; return [] on any failure."""
    try:
        response = requests.get(API_URL_RANDOM, params={"limit": limit}, timeout=30)
        if response.status_code == 200:
            return response.json().get("products") or []
        print(f"Error fetching products: HTTP {response.status_code}")
    except Exception as e:
        print(f"Error fetching products: {e}")
    return []


def display_random_products():
    """Show a cached set of random products with a refresh button."""
    st.markdown("### 🛍️ Discover Products")

    # Cache products in session state to avoid fetching on every rerun.
    if "random_products" not in st.session_state:
        st.session_state.random_products = fetch_random_products(10)

    # Offer refresh even when the list is empty, so a failed fetch can be retried.
    if st.button("🔄 Refresh Products", type="secondary"):
        st.session_state.random_products = fetch_random_products(10)
        st.rerun()

    products = st.session_state.random_products
    if not products:
        st.info("No products available. Try searching instead!")
        return

    _render_product_grid(products, key_prefix="random")


def _post_search(url, timeout, **request_kwargs):
    """POST a search request and return its decoded JSON.

    On a network error, a non-200 status or an invalid JSON body, show an
    error message and return ``None``.
    """
    try:
        response = requests.post(url, timeout=timeout, **request_kwargs)
        if response.status_code != 200:
            st.error(f"Error: {response.status_code}")
            return None
        return response.json()
    except (requests.RequestException, ValueError) as e:
        st.error(str(e))
        return None


def _store_search_results(data, prediction_info):
    """Save a successful search response as the current results."""
    st.session_state.search_results = build_products_list(data)
    st.session_state.prediction_info = prediction_info
    # A new search must replace any open details page with the new results.
    st.session_state.selected_product = None


# ---------------------- HEADER ----------------------
st.markdown("""

🔎 AI Smart Search

Search by Text, Image, or Voice

""", unsafe_allow_html=True)

# ---------------------- COMPACT SEARCH BAR ----------------------
# All search controls in one row
search_col1, search_col2, search_col3 = st.columns([1, 4, 1])

with search_col1:
    search_type = st.selectbox(
        "Type",
        [SEARCH_TEXT, SEARCH_IMAGE, SEARCH_AUDIO],
        label_visibility="collapsed",
    )

with search_col2:
    # Dynamic input based on search type
    if search_type == SEARCH_TEXT:
        query = st.text_input(
            "Search",
            placeholder="Search for products...",
            label_visibility="collapsed",
        )
    elif search_type == SEARCH_IMAGE:
        image_file = st.file_uploader(
            "Upload image",
            type=["png", "jpg", "jpeg"],
            label_visibility="collapsed",
        )
    else:  # Audio
        audio_col1, audio_col2 = st.columns([3, 1])
        with audio_col1:
            audio_method = st.radio(
                "Method",
                ["🎙️ Record", "📁 Upload"],
                horizontal=True,
                label_visibility="collapsed",
            )
        with audio_col2:
            language = st.selectbox(
                "Lang",
                ["en", "ar"],
                format_func=lambda x: "EN" if x == "en" else "AR",
                label_visibility="collapsed",
            )

with search_col3:
    top_k = st.selectbox(
        "Results",
        [10, 20, 30, 50],
        label_visibility="collapsed",
    )

# ---------------------- SEARCH INPUT ROW 2 (for Image/Audio) ----------------------
if search_type == SEARCH_IMAGE and image_file:
    preview_col, btn_col = st.columns([1, 3])
    with preview_col:
        st.image(image_file, width=100)
    with btn_col:
        search_btn = st.button("🔍 Search", type="primary")

    if search_btn:
        image_file.seek(0)
        with st.spinner("Analyzing..."):
            data = _post_search(
                API_URL_IMAGE,
                timeout=60,
                files={"image": (image_file.name, image_file, image_file.type)},
                params={"top_k": top_k},
            )
        if data is not None:
            _store_search_results(data, {
                "type": "image",
                "category": data.get("predicted_category", "Unknown"),
                "confidence": data.get("confidence_score", 0),
            })
            st.rerun()

elif search_type == SEARCH_AUDIO:
    audio_data = None
    audio_filename = "recording.wav"

    if audio_method == "🎙️ Record":
        recorded = st.audio_input(
            "Record",
            key=f"rec_{st.session_state.audio_key}",
            label_visibility="collapsed",
        )
        if recorded:
            audio_data = recorded
    else:
        uploaded = st.file_uploader(
            "Upload",
            type=["wav", "mp3", "m4a"],
            key=f"up_{st.session_state.audio_key}",
            label_visibility="collapsed",
        )
        if uploaded:
            audio_data = uploaded
            audio_filename = uploaded.name

    if audio_data:
        player_col, action_col = st.columns([1, 2])
        with player_col:
            st.audio(audio_data)
        with action_col:
            if st.button("🔍 Search", type="primary"):
                audio_data.seek(0)
                # Send the file's real MIME type; mp3/m4a uploads are not WAV.
                mime_type = getattr(audio_data, "type", None) or "audio/wav"
                with st.spinner("Transcribing..."):
                    data = _post_search(
                        API_URL_AUDIO,
                        timeout=60,
                        files={"audio": (audio_filename, audio_data, mime_type)},
                        params={"top_k": top_k, "language": language},
                    )
                if data is not None:
                    _store_search_results(data, {
                        "type": "audio",
                        "transcription": data.get("caption", ""),
                    })
                    st.session_state.audio_key += 1
                    st.rerun()

elif search_type == SEARCH_TEXT:
    search_clicked = st.button("🔍 Search", type="primary")
    # Also search automatically whenever the submitted query text changes.
    query_changed = bool(query) and st.session_state.get("last_query") != query
    if (search_clicked or query_changed) and query and query.strip():
        # Record the query before the request so a failing search is not
        # re-sent on every later rerun; the button still allows a retry.
        st.session_state.last_query = query
        with st.spinner("Searching..."):
            data = _post_search(
                API_URL_TEXT,
                timeout=30,
                params={"query": query, "top_k": top_k},
            )
        if data is not None:
            _store_search_results(data, {"type": "text", "query": query})
            st.rerun()

# ---------------------- CLEAR BUTTON ----------------------
if st.session_state.search_results and not st.session_state.selected_product:
    if st.button("🗑️ Clear", type="secondary"):
        reset_search()
        st.rerun()

# ---------------------- DISPLAY RESULTS ----------------------
st.markdown("---")

# Show product details if a product is selected
if st.session_state.selected_product:
    display_product_details(st.session_state.selected_product)
elif st.session_state.search_results:
    display_products(
        st.session_state.search_results,
        st.session_state.prediction_info,
    )
else:
    # Show random products before any search
    display_random_products()