# Hugging Face Spaces status: Paused
import os
from pathlib import Path
from urllib.parse import quote

import requests
import streamlit as st
# ---------------------- CONFIG ----------------------
# Base URL of the search backend. Defaults to the hosted Space; set the
# RAIJ_API_BASE_URL environment variable to point the UI at another
# deployment (an empty value falls back to the default). A trailing slash is
# stripped so the endpoint paths below join with exactly one "/".
API_BASE_URL = (
    os.environ.get("RAIJ_API_BASE_URL") or "https://hamdy005-raij-ai.hf.space"
).rstrip("/")
API_URL_TEXT = f"{API_BASE_URL}/search/text"
API_URL_IMAGE = f"{API_BASE_URL}/search/image"
API_URL_AUDIO = f"{API_BASE_URL}/search/audio"
API_URL_PRODUCT = f"{API_BASE_URL}/product"
API_URL_RANDOM = f"{API_BASE_URL}/products/random"

# ---------------------- PAGE CONFIG ----------------------
# Called before any other Streamlit command in this script.
st.set_page_config(
    page_title="AI Smart Search",
    # Previously a mis-decoded byte sequence ("π"), which is not an emoji.
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# ---------------------- LOAD CSS ----------------------
def load_css():
    """Inject ``styles.css`` (located next to this file) into the page.

    The stylesheet is read explicitly as UTF-8 so non-ASCII content does not
    depend on the platform's default encoding. If the file is absent (or the
    path is not a regular file) nothing is injected.
    """
    css_file = Path(__file__).parent / "styles.css"
    if not css_file.is_file():
        return
    css = css_file.read_text(encoding="utf-8")
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
load_css()
# ---------------------- SESSION STATE ----------------------
# Seed the per-session keys the app reads; values already present (from an
# earlier run in this session) are left untouched.
for _state_key, _initial_value in (
    ("audio_key", 0),
    ("search_results", None),
    ("prediction_info", None),
    ("selected_product", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# ---------------------- HELPER FUNCTIONS ----------------------
def build_products_list(data):
    """Turn a search API response into a list of product dicts.

    ``data`` holds parallel lists under ``product_ids`` and ``titles``;
    iteration stops at the shorter of the two. The optional ``image_urls``
    and ``prices`` lists are matched by position, and a missing entry
    (list too short or absent) becomes ``None``.

    Returns:
        A list of dicts with keys ``id``, ``title``, ``image_url``, ``price``.
    """
    image_urls = data.get("image_urls", [])
    prices = data.get("prices", [])

    def _nth(seq, position):
        # Positional lookup that yields None instead of raising past the end.
        return seq[position] if position < len(seq) else None

    id_title_pairs = zip(data.get("product_ids", []), data.get("titles", []))
    return [
        {
            "id": product_id,
            "title": product_title,
            "image_url": _nth(image_urls, position),
            "price": _nth(prices, position),
        }
        for position, (product_id, product_title) in enumerate(id_title_pairs)
    ]
def display_products(products, prediction_info=None):
    """Render search results as a 4-column grid of product cards.

    Args:
        products: list of product dicts with ``id``, ``title``,
            ``image_url`` and ``price`` keys (the shape produced by
            ``build_products_list``).
        prediction_info: optional dict describing the search. For
            ``type == "image"`` the detected ``category`` and ``confidence``
            are shown; for ``type == "audio"`` the ``transcription`` is
            shown; any other type (e.g. ``"text"``) shows no banner.

    Clicking a card's "View Details" button stores that product's id in
    ``st.session_state.selected_product`` and reruns the script.
    """
    # Compact banner explaining how the query was interpreted.
    if prediction_info:
        search_kind = prediction_info.get("type")
        if search_kind == "image":
            category = prediction_info.get("category", "N/A")
            try:
                confidence = f"{float(prediction_info.get('confidence') or 0):.1%}"
            except (TypeError, ValueError):
                confidence = "N/A"
            st.info(f"🖼️ **Detected Category:** {category} ({confidence} confidence)")
        elif search_kind == "audio":
            transcription = prediction_info.get("transcription", "N/A")
            st.info(f"🎤 **Transcription:** \"{transcription}\"")

    st.markdown(f"### 🛍️ {len(products)} Products Found")
    if not products:
        st.info("No matching products. Try a different search.")
        return

    cols_per_row = 4  # compact grid
    for row_start in range(0, len(products), cols_per_row):
        # Always create a full row of columns so card widths stay uniform.
        columns = st.columns(cols_per_row)
        for offset, product in enumerate(products[row_start:row_start + cols_per_row]):
            position = row_start + offset
            with columns[offset]:
                with st.container(border=True):
                    image_url = product.get("image_url")
                    if image_url:
                        try:
                            st.image(image_url, width="stretch")
                        except Exception:
                            # A value st.image cannot handle must not break the grid.
                            st.markdown("🖼️ *No image*")

                    # The API may send a null title; fall back to "N/A".
                    title = str(product.get("title") or "N/A")
                    display_title = title[:40] + "..." if len(title) > 40 else title
                    st.markdown(f"**{display_title}**")

                    price = product.get("price")
                    if price:
                        try:
                            price_text = f"${float(price):.2f}"
                        except (TypeError, ValueError):
                            price_text = None  # non-numeric price: omit it
                        if price_text:
                            st.markdown(
                                f"<span style='color: #2ecc71; font-weight: bold;'>{price_text}</span>",
                                unsafe_allow_html=True,
                            )

                    # The grid position keeps the key unique even when the same
                    # product id appears more than once in one result list;
                    # duplicate widget keys make Streamlit raise an error.
                    if st.button(
                        "View Details",
                        key=f"view_{position}_{product.get('id')}",
                        type="secondary",
                    ):
                        st.session_state.selected_product = product.get("id")
                        st.rerun()
| def display_product_details(product_id): | |
| """Display detailed product information""" | |
| # Back button | |
| if st.button("β Back to Results", type="secondary"): | |
| st.session_state.selected_product = None | |
| st.rerun() | |
| # Fetch product details from API | |
| with st.spinner("Loading product details..."): | |
| try: | |
| response = requests.get(f"{API_URL_PRODUCT}/{product_id}", timeout=30) | |
| if response.status_code == 200: | |
| product = response.json() | |
| if product.get("error"): | |
| st.error("Product not found") | |
| return | |
| # Layout: Image on left, details on right | |
| col1, col2 = st.columns([1, 2]) | |
| with col1: | |
| # Product images | |
| images = product.get("images", []) | |
| if images: | |
| st.image(images[0], width="stretch") | |
| # Show thumbnails if multiple images | |
| if len(images) > 1: | |
| thumb_cols = st.columns(min(4, len(images))) | |
| for idx, img_url in enumerate(images[:4]): | |
| with thumb_cols[idx]: | |
| st.image(img_url, width=80) | |
| else: | |
| st.markdown("πΌοΈ *No image available*") | |
| with col2: | |
| # Product title | |
| st.markdown(f"## {product.get('title', 'N/A')}") | |
| # SKU | |
| sku = product.get('sku') | |
| if sku: | |
| st.caption(f"SKU: {sku}") | |
| st.markdown("---") | |
| # Price section | |
| price = product.get('price') | |
| old_price = product.get('old_price') | |
| if old_price and old_price > price: | |
| discount = int((1 - price / old_price) * 100) | |
| st.markdown(f""" | |
| <div style='margin: 10px 0;'> | |
| <span style='text-decoration: line-through; color: #888; font-size: 1.2rem;'>${old_price:.2f}</span> | |
| <span style='color: #e74c3c; font-size: 1.8rem; font-weight: bold; margin-left: 10px;'>${price:.2f}</span> | |
| <span style='background: #e74c3c; color: white; padding: 3px 8px; border-radius: 4px; margin-left: 10px; font-size: 0.9rem;'>-{discount}%</span> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| elif price: | |
| st.markdown(f"<span style='color: #2ecc71; font-size: 1.8rem; font-weight: bold;'>${price:.2f}</span>", unsafe_allow_html=True) | |
| else: | |
| st.markdown("*Price not available*") | |
| st.markdown("---") | |
| # Stock status | |
| stock = product.get('stock', 0) | |
| if stock > 0: | |
| st.success(f"β In Stock ({stock} available)") | |
| else: | |
| st.error("β Out of Stock") | |
| st.markdown("---") | |
| # Description | |
| st.markdown("### π Description") | |
| description = product.get('description', 'No description available.') | |
| st.write(description) | |
| # Tags | |
| tags = product.get('tags', []) | |
| if tags: | |
| st.markdown("### π·οΈ Tags") | |
| st.write(" β’ ".join(tags)) | |
| else: | |
| st.error(f"Error loading product: {response.status_code}") | |
| except Exception as e: | |
| st.error(f"Error: {str(e)}") | |
def reset_search():
    """Forget the current results, the prediction banner and any selected product."""
    for state_key in ("search_results", "prediction_info", "selected_product"):
        st.session_state[state_key] = None
def fetch_random_products(limit=10):
    """Ask the API for up to *limit* random products.

    On HTTP 200, returns whatever the JSON body holds under ``products``
    (``[]`` if the key is absent). Any other status returns ``[]``. Request
    or decoding errors are printed and also return ``[]``.
    """
    try:
        reply = requests.get(API_URL_RANDOM, params={"limit": limit}, timeout=30)
        if reply.status_code != 200:
            return []
        return reply.json().get("products", [])
    except Exception as exc:
        print(f"Error fetching products: {exc}")
        return []
def display_random_products():
    """Show a grid of random products as the landing view before any search.

    The list from ``fetch_random_products(10)`` is cached in
    ``st.session_state.random_products`` so reruns do not re-request it.
    The "Refresh Products" button fetches a new batch.
    """
    st.markdown("### 🛍️ Discover Products")

    # Cache products in session state to avoid fetching on every rerun.
    if "random_products" not in st.session_state:
        st.session_state.random_products = fetch_random_products(10)

    # Rendered before the empty check: a failed fetch caches an empty list,
    # and without this button the user would have no way to retry.
    if st.button("🔄 Refresh Products", type="secondary"):
        st.session_state.random_products = fetch_random_products(10)
        st.rerun()

    products = st.session_state.random_products
    if not products:
        st.info("No products available. Try searching instead!")
        return

    cols_per_row = 4
    for row_start in range(0, len(products), cols_per_row):
        # Always create a full row of columns so card widths stay uniform.
        columns = st.columns(cols_per_row)
        for offset, product in enumerate(products[row_start:row_start + cols_per_row]):
            position = row_start + offset
            with columns[offset]:
                with st.container(border=True):
                    image_url = product.get("image_url")
                    if image_url:
                        try:
                            st.image(image_url, width="stretch")
                        except Exception:
                            # A value st.image cannot handle must not break the grid.
                            st.markdown("🖼️ *No image*")

                    # The API may send a null title; fall back to "N/A".
                    title = str(product.get("title") or "N/A")
                    display_title = title[:40] + "..." if len(title) > 40 else title
                    st.markdown(f"**{display_title}**")

                    price = product.get("price")
                    if price:
                        try:
                            price_text = f"${float(price):.2f}"
                        except (TypeError, ValueError):
                            price_text = None  # non-numeric price: omit it
                        if price_text:
                            st.markdown(
                                f"<span style='color: #2ecc71; font-weight: bold;'>{price_text}</span>",
                                unsafe_allow_html=True,
                            )

                    # The grid position keeps the widget key unique even if the
                    # API returns the same product twice.
                    if st.button(
                        "View Details",
                        key=f"random_{position}_{product.get('id')}",
                        type="secondary",
                    ):
                        st.session_state.selected_product = product.get("id")
                        st.rerun()
# ---------------------- HEADER ----------------------
# Styled by the "main-header" class, presumably defined in styles.css
# (loaded by load_css) — confirm there.
st.markdown(
    """
<div class="main-header">
<h1>🔍 AI Smart Search</h1>
<p>Search by Text, Image, or Voice</p>
</div>
""",
    unsafe_allow_html=True,
)
# ---------------------- COMPACT SEARCH BAR ----------------------
# Option labels for the mode selectors. They are also the values compared
# below, so each label is defined once here.
MODE_TEXT = "📝 Text"
MODE_IMAGE = "🖼️ Image"
MODE_AUDIO = "🎤 Audio"
AUDIO_RECORD = "🎙️ Record"
AUDIO_UPLOAD = "📁 Upload"


def _execute_search(url, params, make_prediction_info, *, files=None, timeout=60):
    """POST a search request and store its results in session state.

    Args:
        url: search endpoint to call.
        params: query-string parameters for the request.
        make_prediction_info: callable mapping the decoded JSON response to
            the dict stored in ``st.session_state.prediction_info``.
        files: optional multipart files (image/audio searches).
        timeout: request timeout in seconds.

    Returns:
        True if results were stored (the caller should ``st.rerun()``), or
        False after the failure has been shown with ``st.error``.
    """
    try:
        response = requests.post(url, files=files, params=params, timeout=timeout)
        if response.status_code != 200:
            st.error(f"Error: {response.status_code}")
            return False
        data = response.json()
        results = build_products_list(data)
        prediction_info = make_prediction_info(data)
    except Exception as e:  # UI boundary: show the failure instead of a traceback
        st.error(str(e))
        return False
    st.session_state.search_results = results
    st.session_state.prediction_info = prediction_info
    # Leave any open product-details view so the new results are visible.
    st.session_state.selected_product = None
    return True


# Inputs that only exist in some modes; defaulted so later checks never
# reference an undefined name.
query = None
image_file = None
audio_method = AUDIO_RECORD
language = "en"

# All search controls in one row: mode | input | result count.
search_col1, search_col2, search_col3 = st.columns([1, 4, 1])
with search_col1:
    search_type = st.selectbox(
        "Type",
        [MODE_TEXT, MODE_IMAGE, MODE_AUDIO],
        label_visibility="collapsed",
    )
with search_col2:
    # Dynamic input based on search type.
    if search_type == MODE_TEXT:
        query = st.text_input(
            "Search",
            placeholder="Search for products...",
            label_visibility="collapsed",
        )
    elif search_type == MODE_IMAGE:
        image_file = st.file_uploader(
            "Upload image",
            type=["png", "jpg", "jpeg"],
            label_visibility="collapsed",
        )
    else:  # Audio
        audio_col1, audio_col2 = st.columns([3, 1])
        with audio_col1:
            audio_method = st.radio(
                "Method",
                [AUDIO_RECORD, AUDIO_UPLOAD],
                horizontal=True,
                label_visibility="collapsed",
            )
        with audio_col2:
            language = st.selectbox(
                "Lang",
                ["en", "ar"],
                format_func=lambda x: "EN" if x == "en" else "AR",
                label_visibility="collapsed",
            )
with search_col3:
    top_k = st.selectbox(
        "Results",
        [10, 20, 30, 50],
        label_visibility="collapsed",
    )

# ---------------------- SEARCH INPUT ROW 2 (for Image/Audio) ----------------------
if search_type == MODE_IMAGE and image_file:
    preview_col, btn_col = st.columns([1, 3])
    with preview_col:
        st.image(image_file, width=100)
    with btn_col:
        search_btn = st.button("🔍 Search", type="primary")
    if search_btn:
        with st.spinner("Analyzing..."):
            image_file.seek(0)  # rewind in case the preview read the buffer
            files = {"image": (image_file.name, image_file, image_file.type)}
            found = _execute_search(
                API_URL_IMAGE,
                {"top_k": top_k},
                lambda data: {
                    "type": "image",
                    "category": data.get("predicted_category", "Unknown"),
                    "confidence": data.get("confidence_score", 0),
                },
                files=files,
                timeout=60,
            )
        if found:
            st.rerun()
elif search_type == MODE_AUDIO:
    audio_data = None
    audio_filename = "recording.wav"
    # audio_key is part of each widget key so a successful search can
    # reset the recorder/uploader.
    if audio_method == AUDIO_RECORD:
        recorded = st.audio_input(
            "Record",
            key=f"rec_{st.session_state.audio_key}",
            label_visibility="collapsed",
        )
        if recorded:
            audio_data = recorded
    else:
        uploaded = st.file_uploader(
            "Upload",
            type=["wav", "mp3", "m4a"],
            key=f"up_{st.session_state.audio_key}",
            label_visibility="collapsed",
        )
        if uploaded:
            audio_data = uploaded
            audio_filename = uploaded.name
    if audio_data:
        # Uploads may be MP3/M4A; use the file's own MIME type and fall back
        # to WAV (the previous hardcoded value) when it is unknown.
        mime_type = getattr(audio_data, "type", None) or "audio/wav"
        col1, col2 = st.columns([1, 2])
        with col1:
            st.audio(audio_data, format=mime_type)
        with col2:
            if st.button("🔍 Search", type="primary"):
                with st.spinner("Transcribing..."):
                    audio_data.seek(0)
                    files = {"audio": (audio_filename, audio_data, mime_type)}
                    found = _execute_search(
                        API_URL_AUDIO,
                        {"top_k": top_k, "language": language},
                        lambda data: {
                            "type": "audio",
                            "transcription": data.get("caption", ""),
                        },
                        files=files,
                        timeout=60,
                    )
                if found:
                    # A new widget key clears the recorder/uploader.
                    st.session_state.audio_key += 1
                    st.rerun()
elif search_type == MODE_TEXT:
    search_clicked = st.button("🔍 Search", type="primary")
    # Committing a new query in the text box also triggers a search.
    query_changed = bool(query) and st.session_state.get("last_query") != query
    if (search_clicked or query_changed) and query and query.strip():
        # Record the query before sending it, so a failing request is not
        # re-sent on every rerun; the Search button retries it explicitly.
        st.session_state.last_query = query
        with st.spinner("Searching..."):
            found = _execute_search(
                API_URL_TEXT,
                {"query": query, "top_k": top_k},
                lambda _data: {"type": "text", "query": query},
                timeout=30,
            )
        if found:
            st.rerun()
# ---------------------- CLEAR BUTTON ----------------------
# Shown with the results grid only (hidden while a product's details are
# open). "is not None" so a search that matched nothing can still be cleared.
if (
    st.session_state.search_results is not None
    and st.session_state.selected_product is None
):
    if st.button("🗑️ Clear", type="secondary"):
        reset_search()
        st.rerun()
# ---------------------- DISPLAY RESULTS ----------------------
st.markdown("---")
# Precedence: product details > search results > random "discover" grid.
# Explicit None checks so a product id of 0/"" can still be opened and an
# empty result list shows "0 Products Found" instead of random products.
if st.session_state.selected_product is not None:
    display_product_details(st.session_state.selected_product)
elif st.session_state.search_results is not None:
    display_products(
        st.session_state.search_results,
        st.session_state.prediction_info,
    )
else:
    # Show random products before any search.
    display_random_products()