# Raij-AI-UI / src/streamlit_app.py
# Hamdy005's picture
# Update src/streamlit_app.py
# ecad91b verified
import streamlit as st
import requests
from pathlib import Path
# ---------------------- CONFIG ----------------------
# Base URL of the API; every endpoint below is derived from it.
API_BASE_URL = "https://hamdy005-raij-ai.hf.space"

# Search endpoints, one per input modality.
API_URL_TEXT = API_BASE_URL + "/search/text"
API_URL_IMAGE = API_BASE_URL + "/search/image"
API_URL_AUDIO = API_BASE_URL + "/search/audio"

# Catalogue endpoints: per-product lookup (the id is appended by the caller)
# and random product sampling.
API_URL_PRODUCT = API_BASE_URL + "/product"
API_URL_RANDOM = API_BASE_URL + "/products/random"
# ---------------------- PAGE CONFIG ----------------------
# Page-level settings, gathered in one mapping and applied in a single call.
_PAGE_SETTINGS = {
    "page_title": "AI Smart Search",
    "page_icon": "πŸ”Ž",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
# ---------------------- LOAD CSS ----------------------
def load_css():
    """Inject the stylesheet that sits next to this script into the page.

    Looks for ``styles.css`` in the directory containing this file and, when
    it is a regular file, embeds its contents in a ``<style>`` tag. A missing
    stylesheet is not an error: the function returns without touching the
    page.
    """
    css_file = Path(__file__).parent / "styles.css"
    # is_file() (rather than exists()) also skips a directory that happens to
    # carry the same name, which could not be read as text anyway.
    if not css_file.is_file():
        return
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    css = css_file.read_text(encoding="utf-8")
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


load_css()
# ---------------------- SESSION STATE ----------------------
# Seed each session key with its initial value; keys that already exist
# (i.e. on every rerun after the first) are left untouched.
_SESSION_DEFAULTS = (
    ("audio_key", 0),
    ("search_results", None),
    ("prediction_info", None),
    ("selected_product", None),
)
for _state_key, _initial_value in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# ---------------------- HELPER FUNCTIONS ----------------------
def build_products_list(data):
    """Convert a column-oriented API response into a list of product dicts.

    Args:
        data: Mapping decoded from the API response. The keys read are
            ``product_ids``, ``titles``, ``image_urls`` and ``prices``; each
            is treated as a list aligned by position. A missing or null
            column is treated as empty, and a null ``data`` yields no
            products.

    Returns:
        One dict per (product id, title) pair with the keys ``id``,
        ``title``, ``image_url`` and ``price``. The number of products is
        bounded by the shorter of ``product_ids`` and ``titles``;
        ``image_url`` / ``price`` are ``None`` where their column is shorter.
    """
    data = data or {}

    def column(key):
        # ``dict.get(key, [])`` only covers a *missing* key; a key that is
        # present but null would otherwise crash len() / zip() below.
        return data.get(key) or []

    image_urls = column("image_urls")
    prices = column("prices")

    return [
        {
            "id": pid,
            "title": title,
            "image_url": image_urls[idx] if idx < len(image_urls) else None,
            "price": prices[idx] if idx < len(prices) else None,
        }
        for idx, (pid, title) in enumerate(
            zip(column("product_ids"), column("titles"))
        )
    ]
def display_products(products, prediction_info=None):
    """Render search results as a four-column grid of product cards.

    Args:
        products: List of product dicts; the keys read are ``id``, ``title``,
            ``image_url`` and ``price``.
        prediction_info: Optional dict describing how the search was
            interpreted. For ``type == "image"`` the ``category`` and
            ``confidence`` entries are shown; for ``type == "audio"`` the
            ``transcription`` entry is shown. Any other type (including
            ``"text"``) shows no banner.

    Side effects:
        Clicking a card's "View Details" button stores that product's id in
        ``st.session_state.selected_product`` and reruns the script.
    """
    # Show prediction info as a compact banner
    info_type = prediction_info.get("type") if prediction_info else None
    if info_type == "image":
        st.info(f"πŸ–ΌοΈ **Detected Category:** {prediction_info.get('category', 'N/A')} ({prediction_info.get('confidence', 0):.1%} confidence)")
    elif info_type == "audio":
        st.info(f"🎀 **Transcription:** \"{prediction_info.get('transcription', 'N/A')}\"")

    # Products header
    st.markdown(f"### πŸ›οΈ {len(products)} Products Found")

    # Create grid layout - 4 columns for more compact display
    cols_per_row = 4
    for row_start in range(0, len(products), cols_per_row):
        row_products = products[row_start:row_start + cols_per_row]
        cols = st.columns(cols_per_row)
        # zip() stops at the end of a short final row, leaving the remaining
        # columns empty.
        for offset, (col, product) in enumerate(zip(cols, row_products)):
            position = row_start + offset
            with col:
                with st.container(border=True):
                    # Show image first if available
                    image_url = product.get('image_url')
                    if image_url:
                        try:
                            st.image(image_url, width="stretch")
                        except Exception:
                            # Best effort: a broken image must not take the
                            # whole grid down. ``Exception`` (not a bare
                            # ``except``) keeps KeyboardInterrupt and
                            # Streamlit's control-flow signals propagating.
                            st.markdown("πŸ–ΌοΈ *No image*")

                    # Product title (compact); ``or`` also covers a title
                    # that is present but null, which would break len().
                    title = product.get('title') or 'N/A'
                    display_title = title[:40] + '...' if len(title) > 40 else title
                    st.markdown(f"**{display_title}**")

                    # Price
                    price = product.get('price')
                    if price:
                        st.markdown(f"<span style='color: #2ecc71; font-weight: bold;'>${price:.2f}</span>", unsafe_allow_html=True)

                    # View Details button. The grid position is part of the
                    # key because widget keys must be unique and the same
                    # product id may appear more than once in the results.
                    if st.button("View Details", key=f"view_{position}_{product.get('id')}", type="secondary"):
                        st.session_state.selected_product = product.get('id')
                        st.rerun()
def display_product_details(product_id):
    """Fetch a single product from the API and render its detail page.

    Renders a back button, then requests ``{API_URL_PRODUCT}/{product_id}``
    and lays the response out as an image column beside a details column
    (title, SKU, price, stock, description, tags). Optional fields that are
    missing or null fall back to a placeholder instead of failing.

    Args:
        product_id: Identifier appended to the product endpoint URL.

    Side effects:
        Clicking the back button clears
        ``st.session_state.selected_product`` and reruns the script. Request
        or rendering errors are reported with ``st.error`` rather than
        raised.
    """
    # Back button
    if st.button("← Back to Results", type="secondary"):
        st.session_state.selected_product = None
        st.rerun()

    # Fetch product details from API
    with st.spinner("Loading product details..."):
        try:
            response = requests.get(f"{API_URL_PRODUCT}/{product_id}", timeout=30)
            if response.status_code != 200:
                st.error(f"Error loading product: {response.status_code}")
                return

            product = response.json()
            if product.get("error"):
                st.error("Product not found")
                return

            # Layout: Image on left, details on right
            col1, col2 = st.columns([1, 2])

            with col1:
                # Product images (``or []`` also covers an explicit null)
                images = product.get("images") or []
                if images:
                    st.image(images[0], width="stretch")
                    # Show thumbnails if multiple images (at most four)
                    if len(images) > 1:
                        thumb_cols = st.columns(min(4, len(images)))
                        for thumb_col, img_url in zip(thumb_cols, images[:4]):
                            with thumb_col:
                                st.image(img_url, width=80)
                else:
                    st.markdown("πŸ–ΌοΈ *No image available*")

            with col2:
                # Product title
                st.markdown(f"## {product.get('title', 'N/A')}")

                # SKU
                sku = product.get('sku')
                if sku:
                    st.caption(f"SKU: {sku}")

                st.markdown("---")

                # Price section
                price = product.get('price')
                old_price = product.get('old_price')
                # A discount needs both prices: comparing ``old_price`` with
                # a missing (None) price would raise TypeError.
                if price is not None and old_price and old_price > price:
                    discount = int((1 - price / old_price) * 100)
                    st.markdown(f"""
<div style='margin: 10px 0;'>
<span style='text-decoration: line-through; color: #888; font-size: 1.2rem;'>${old_price:.2f}</span>
<span style='color: #e74c3c; font-size: 1.8rem; font-weight: bold; margin-left: 10px;'>${price:.2f}</span>
<span style='background: #e74c3c; color: white; padding: 3px 8px; border-radius: 4px; margin-left: 10px; font-size: 0.9rem;'>-{discount}%</span>
</div>
""", unsafe_allow_html=True)
                elif price:
                    st.markdown(f"<span style='color: #2ecc71; font-size: 1.8rem; font-weight: bold;'>${price:.2f}</span>", unsafe_allow_html=True)
                else:
                    st.markdown("*Price not available*")

                st.markdown("---")

                # Stock status; a null stock counts as zero
                stock = product.get('stock') or 0
                if stock > 0:
                    st.success(f"βœ… In Stock ({stock} available)")
                else:
                    st.error("❌ Out of Stock")

                st.markdown("---")

                # Description; a null or empty value gets the placeholder
                st.markdown("### πŸ“ Description")
                description = product.get('description') or 'No description available.'
                st.write(description)

                # Tags
                tags = product.get('tags') or []
                if tags:
                    st.markdown("### 🏷️ Tags")
                    # str() so that a non-string tag cannot break the join
                    st.write(" β€’ ".join(str(tag) for tag in tags))
        except Exception as e:
            # UI boundary: report any request/parse/render failure on the
            # page instead of letting the script crash.
            st.error(f"Error: {str(e)}")
def reset_search():
    """Clear the stored results, their prediction info and the open product."""
    for state_key in ("search_results", "prediction_info", "selected_product"):
        st.session_state[state_key] = None
def fetch_random_products(limit=10):
    """Request up to ``limit`` products from the random-products endpoint.

    Returns the ``products`` entry of a 200 response. A non-200 status or
    any exception (which is printed) results in an empty list.
    """
    products = []
    try:
        reply = requests.get(API_URL_RANDOM, timeout=30, params={"limit": limit})
        if reply.status_code == 200:
            products = reply.json().get("products", [])
    except Exception as exc:
        print(f"Error fetching products: {exc}")
    return products
def display_random_products():
    """Render the landing view: a refreshable four-column grid of products.

    The product list is fetched once with ``fetch_random_products(10)`` and
    cached in ``st.session_state.random_products`` so that ordinary reruns do
    not hit the API again; the Refresh button replaces the cached list.

    Side effects:
        Clicking a card's "View Details" button stores that product's id in
        ``st.session_state.selected_product`` and reruns the script.
    """
    st.markdown("### πŸ›οΈ Discover Products")

    # Cache products in session state to avoid fetching on every rerun
    if "random_products" not in st.session_state:
        st.session_state.random_products = fetch_random_products(10)

    # Refresh button. It is rendered *before* the empty check so that an
    # empty or failed first fetch (which is cached as []) can be retried;
    # otherwise the early return below would hide the only way to refetch.
    if st.button("πŸ”„ Refresh Products", type="secondary"):
        st.session_state.random_products = fetch_random_products(10)
        st.rerun()

    products = st.session_state.random_products
    if not products:
        st.info("No products available. Try searching instead!")
        return

    # Create grid layout - 4 columns
    cols_per_row = 4
    for row_start in range(0, len(products), cols_per_row):
        row_products = products[row_start:row_start + cols_per_row]
        cols = st.columns(cols_per_row)
        # zip() stops at the end of a short final row, leaving the remaining
        # columns empty.
        for offset, (col, product) in enumerate(zip(cols, row_products)):
            position = row_start + offset
            with col:
                with st.container(border=True):
                    # Show image first if available
                    image_url = product.get('image_url')
                    if image_url:
                        try:
                            st.image(image_url, width="stretch")
                        except Exception:
                            # Best effort: a broken image must not take the
                            # whole grid down. ``Exception`` (not a bare
                            # ``except``) keeps KeyboardInterrupt and
                            # Streamlit's control-flow signals propagating.
                            st.markdown("πŸ–ΌοΈ *No image*")

                    # Product title (compact); ``or`` also covers a title
                    # that is present but null, which would break len().
                    title = product.get('title') or 'N/A'
                    display_title = title[:40] + '...' if len(title) > 40 else title
                    st.markdown(f"**{display_title}**")

                    # Price
                    price = product.get('price')
                    if price:
                        st.markdown(f"<span style='color: #2ecc71; font-weight: bold;'>${price:.2f}</span>", unsafe_allow_html=True)

                    # View Details button. The grid position is part of the
                    # key because widget keys must be unique and a product id
                    # may repeat within the list.
                    if st.button("View Details", key=f"random_{position}_{product.get('id')}", type="secondary"):
                        st.session_state.selected_product = product.get('id')
                        st.rerun()
# ---------------------- HEADER ----------------------
# Page banner: the app title plus a one-line summary of the search modes.
_HEADER_HTML = """
<div class="main-header">
<h1>πŸ”Ž AI Smart Search</h1>
<p>Search by Text, Image, or Voice</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# ---------------------- COMPACT SEARCH BAR ----------------------
# All search controls in one row
# One row of controls: search mode | mode-specific input | result count.
mode_col, input_col, count_col = st.columns([1, 4, 1])

with mode_col:
    search_type = st.selectbox(
        "Type",
        ["πŸ“ Text", "πŸ–ΌοΈ Image", "🎀 Audio"],
        label_visibility="collapsed",
    )

with input_col:
    # Only the widget(s) for the chosen mode are created, so `query`,
    # `image_file` and `audio_method` / `language` are each bound only in
    # their own mode.
    if search_type == "πŸ“ Text":
        query = st.text_input(
            "Search",
            placeholder="Search for products...",
            label_visibility="collapsed",
        )
    elif search_type == "πŸ–ΌοΈ Image":
        image_file = st.file_uploader(
            "Upload image",
            type=["png", "jpg", "jpeg"],
            label_visibility="collapsed",
        )
    else:  # Audio
        method_col, lang_col = st.columns([3, 1])
        with method_col:
            audio_method = st.radio(
                "Method",
                ["πŸŽ™οΈ Record", "πŸ“ Upload"],
                horizontal=True,
                label_visibility="collapsed",
            )
        with lang_col:
            # The two language codes are shown upper-cased ("EN" / "AR").
            language = st.selectbox(
                "Lang",
                ["en", "ar"],
                format_func=str.upper,
                label_visibility="collapsed",
            )

with count_col:
    top_k = st.selectbox(
        "Results",
        [10, 20, 30, 50],
        label_visibility="collapsed",
    )
# ---------------------- SEARCH INPUT ROW 2 (for Image/Audio) ----------------------
def _store_search_results(data, prediction_info):
    """Save a successful search in session state and rerun to display it."""
    st.session_state.search_results = build_products_list(data)
    st.session_state.prediction_info = prediction_info
    # A fresh search must leave any open product page: the details view takes
    # precedence when rendering, so the new results would otherwise stay
    # hidden behind it.
    st.session_state.selected_product = None
    st.rerun()


if search_type == "πŸ–ΌοΈ Image" and image_file:
    preview_col, btn_col = st.columns([1, 3])
    with preview_col:
        st.image(image_file, width=100)
    with btn_col:
        search_btn = st.button("πŸ” Search", type="primary")

    if search_btn:
        with st.spinner("Analyzing..."):
            try:
                # Rewind: the preview above may already have consumed the file.
                image_file.seek(0)
                files = {"image": (image_file.name, image_file, image_file.type)}
                response = requests.post(API_URL_IMAGE, files=files, params={"top_k": top_k}, timeout=60)
                if response.status_code == 200:
                    data = response.json()
                    _store_search_results(data, {
                        "type": "image",
                        "category": data.get("predicted_category", "Unknown"),
                        "confidence": data.get("confidence_score", 0),
                    })
                else:
                    st.error(f"Error: {response.status_code}")
            except Exception as e:
                st.error(str(e))

elif search_type == "🎀 Audio":
    audio_data = None
    audio_filename = "recording.wav"

    # The widget keys embed `audio_key`; it is incremented after a successful
    # search, which gives the next run a fresh, empty recorder/uploader.
    if audio_method == "πŸŽ™οΈ Record":
        recorded = st.audio_input("Record", key=f"rec_{st.session_state.audio_key}", label_visibility="collapsed")
        if recorded:
            audio_data = recorded
    else:
        uploaded = st.file_uploader("Upload", type=["wav", "mp3", "m4a"], key=f"up_{st.session_state.audio_key}", label_visibility="collapsed")
        if uploaded:
            audio_data = uploaded
            audio_filename = uploaded.name

    if audio_data:
        col1, col2 = st.columns([1, 2])
        with col1:
            st.audio(audio_data)
        with col2:
            if st.button("πŸ” Search", type="primary"):
                with st.spinner("Transcribing..."):
                    try:
                        audio_data.seek(0)
                        # NOTE(review): the MIME type is hard-coded to
                        # "audio/wav" even for uploaded mp3/m4a files; left
                        # as-is because the backend's expectations are not
                        # visible here - confirm before changing.
                        files = {"audio": (audio_filename, audio_data, "audio/wav")}
                        response = requests.post(API_URL_AUDIO, files=files, params={"top_k": top_k, "language": language}, timeout=60)
                        if response.status_code == 200:
                            data = response.json()
                            st.session_state.audio_key += 1
                            _store_search_results(data, {
                                "type": "audio",
                                "transcription": data.get("caption", ""),
                            })
                        else:
                            st.error(f"Error: {response.status_code}")
                    except Exception as e:
                        st.error(str(e))

elif search_type == "πŸ“ Text":
    # Search on an explicit click, or automatically when the query text
    # differs from the last one that was submitted.
    if st.button("πŸ” Search", type="primary") or (query and st.session_state.get("last_query") != query):
        if query and query.strip():
            # Record the attempt up front. If this were done only on success,
            # a failing query would never match `last_query` and would be
            # re-posted on every rerun (any widget interaction). The button
            # still allows an explicit retry.
            st.session_state.last_query = query
            with st.spinner("Searching..."):
                try:
                    response = requests.post(API_URL_TEXT, params={"query": query, "top_k": top_k}, timeout=30)
                    if response.status_code == 200:
                        data = response.json()
                        _store_search_results(data, {"type": "text", "query": query})
                    else:
                        st.error(f"Error: {response.status_code}")
                except Exception as e:
                    st.error(str(e))
# ---------------------- CLEAR BUTTON ----------------------
# Offer "Clear" only while a results list is on screen, i.e. there are
# results and no product page is open.
showing_results = bool(st.session_state.search_results) and not st.session_state.selected_product
if showing_results and st.button("πŸ—‘οΈ Clear", type="secondary"):
    reset_search()
    st.rerun()
# ---------------------- DISPLAY RESULTS ----------------------
st.markdown("---")

# Precedence: an open product page first, then search results, and finally
# the random product showcase when nothing has been searched yet.
open_product = st.session_state.selected_product
if open_product:
    display_product_details(open_product)
elif st.session_state.search_results:
    display_products(
        st.session_state.search_results,
        st.session_state.prediction_info,
    )
else:
    display_random_products()