"""Streamlit app that scores a web page with a shopping classifier.

The page is loaded in headless Chrome Canary, its text is extracted (from the
AnnotatedPageContent proto when available, otherwise from ``innerText``),
chunked into passages, embedded with a TFLite passage embedder and finally
scored by a TFLite shopping classifier.
"""

import base64
import os
import time

import numpy as np
import sentencepiece as spm
import streamlit as st
from ai_edge_litert.interpreter import Interpreter
from selenium import webdriver
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeService

import common_quality_data_pb2 as apc_pb2

# --- Paths ---
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
EMBEDDER_PATH = os.path.join(BASE_DIR, "passage_embedder", "model.tflite")
CLASSIFIER_PATH = os.path.join(BASE_DIR, "shopping_classifier", "model.tflite")
SPM_PATH = os.path.join(BASE_DIR, "passage_embedder", "sentencepiece.model")
CHROME_CANARY = os.path.expandvars(
    r"%LOCALAPPDATA%\Google\Chrome SxS\Application\chrome.exe"
)

INPUT_WINDOW_SIZE = 64
EMBEDDING_DIM = 768
MAX_WORDS_PER_PASSAGE = 100
MIN_WORDS_PER_PASSAGE = 5
MAX_PASSAGES = 10


# --- Load models once ---
@st.cache_resource
def load_sp():
    """Load the SentencePiece model at SPM_PATH (cached by Streamlit)."""
    sp = spm.SentencePieceProcessor()
    sp.Load(SPM_PATH)
    return sp


@st.cache_resource
def load_embedder():
    """Create the passage-embedder interpreter with tensors allocated (cached).

    NOTE(review): st.cache_resource hands the same object to every caller;
    confirm that sharing one Interpreter is acceptable for this deployment.
    """
    interp = Interpreter(model_path=EMBEDDER_PATH)
    interp.allocate_tensors()
    return interp


@st.cache_resource
def load_classifier():
    """Create the classifier interpreter with tensors allocated (cached)."""
    interp = Interpreter(model_path=CLASSIFIER_PATH)
    interp.allocate_tensors()
    return interp


# --- Text extraction from AnnotatedPageContent proto ---
def extract_text_from_node(node):
    """Recursively extract text items from a ContentNode tree.

    For each node, at most one string is taken: the text content, else the
    table name, else the image caption (whichever attribute field is set
    first in that order). Empty/whitespace-only strings are skipped.
    Children are visited in order after the node itself.

    Returns:
        A flat list of stripped, non-empty strings.
    """
    items = []
    attrs = node.content_attributes

    if attrs.HasField("text_data"):
        text = attrs.text_data.text_content.strip()
        if text:
            items.append(text)
    elif attrs.HasField("table_data"):
        text = attrs.table_data.table_name.strip()
        if text:
            items.append(text)
    elif attrs.HasField("image_data"):
        text = attrs.image_data.image_caption.strip()
        if text:
            items.append(text)

    for child in node.children_nodes:
        items.extend(extract_text_from_node(child))
    return items


def chunk_passages(text_items, max_words=MAX_WORDS_PER_PASSAGE,
                   min_words=MIN_WORDS_PER_PASSAGE,
                   max_passages=MAX_PASSAGES):
    """Greedy word-count chunking matching Chrome's algorithm.

    Args:
        text_items: Iterable of text strings in document order.
        max_words: Word budget at which the current passage is closed.
        min_words: Items with fewer words than this are always appended to
            the current passage, never used to start a new one.
        max_passages: Maximum number of passages returned.

    Returns:
        A list of at most ``max_passages`` space-joined passages.
    """
    passages = []
    current = []
    current_word_count = 0

    # NOTE(review): the flush and limit checks below are assumed to run once
    # per item (loop level) - confirm against the intended nesting.
    for item in text_items:
        item_word_count = len(item.split())

        if item_word_count < min_words:
            # Short fragments ride along with whatever is being built.
            current.append(item)
            current_word_count += item_word_count
        elif current_word_count + item_word_count > max_words and current:
            # Item would overflow the budget: close the passage, start anew.
            passages.append(" ".join(current))
            current = [item]
            current_word_count = item_word_count
        else:
            current.append(item)
            current_word_count += item_word_count

        if current_word_count >= max_words:
            passages.append(" ".join(current))
            current = []
            current_word_count = 0

        if len(passages) >= max_passages:
            break

    # Keep the trailing partial passage only if there is still room for it.
    if current and len(passages) < max_passages:
        passages.append(" ".join(current))

    # One iteration can close two passages, so enforce the cap explicitly.
    return passages[:max_passages]


# --- Tokenization ---
def tokenize(sp, text):
    """SentencePiece encode, append EOS if room, resize to INPUT_WINDOW_SIZE.

    Returns:
        An int32 array of shape (1, INPUT_WINDOW_SIZE); shorter sequences are
        zero-padded on the right, longer ones are truncated (without EOS).
    """
    token_ids = sp.Encode(text)
    if len(token_ids) < INPUT_WINDOW_SIZE:
        token_ids.append(sp.eos_id())
    token_ids = token_ids[:INPUT_WINDOW_SIZE]
    # Zero-pad
    token_ids += [0] * (INPUT_WINDOW_SIZE - len(token_ids))
    return np.array(token_ids, dtype=np.int32).reshape(1, INPUT_WINDOW_SIZE)


# --- Embedding ---
def embed(interp, token_ids):
    """Run passage embedder: int32[1,64] -> float32[1,768].

    The output tensor is copied so the result stays valid after the next
    invocation of the same interpreter.
    """
    input_details = interp.get_input_details()
    output_details = interp.get_output_details()
    interp.set_tensor(input_details[0]["index"], token_ids)
    interp.invoke()
    return interp.get_tensor(output_details[0]["index"]).copy()


# --- Classification ---
def classify(interp, input_vector):
    """Run shopping classifier: float32[1,1536] -> float32[1,1].

    Returns:
        The single output value as a Python float.
    """
    input_details = interp.get_input_details()
    output_details = interp.get_output_details()
    interp.set_tensor(input_details[0]["index"], input_vector)
    interp.invoke()
    return float(interp.get_tensor(output_details[0]["index"])[0][0])


# --- CDP page extraction ---
def fetch_page_content(url):
    """Use Chrome Canary + Selenium CDP to get AnnotatedPageContent.

    Returns:
        A tuple ``(apc_data, title, page_url, inner_text)``. ``apc_data`` is
        the decoded proto bytes, or None when the CDP call failed (a Streamlit
        warning is shown in that case). ``title``, ``page_url`` and
        ``inner_text`` come from the live page and serve as the fallback;
        ``inner_text`` may be None if the document has no body.

    The driver is always quit, including when navigation raises.
    """
    options = ChromeOptions()
    options.binary_location = CHROME_CANARY
    options.add_argument("--headless=new")
    options.add_argument("--disable-gpu")
    options.add_argument("--no-sandbox")

    driver = webdriver.Chrome(options=options)
    try:
        driver.get(url)
        # Wait for content to settle (Chrome uses 5s delay)
        time.sleep(5)

        # Try AnnotatedPageContent via CDP
        apc_data = None
        try:
            result = driver.execute_cdp_cmd(
                "Page.getAnnotatedPageContent",
                {"includeActionableInformation": True},
            )
            apc_data = base64.b64decode(result["content"])
        except Exception as e:
            st.warning(f"CDP AnnotatedPageContent failed: {e}")

        # Fallback: get title and innerText
        title = driver.title
        inner_text = driver.execute_script("return document.body.innerText")
        page_url = driver.current_url
    finally:
        driver.quit()

    return apc_data, title, page_url, inner_text


def process_apc(apc_data):
    """Parse AnnotatedPageContent proto and extract title, url, text items."""
    apc = apc_pb2.AnnotatedPageContent()
    apc.ParseFromString(apc_data)

    title = apc.main_frame_data.title
    url = apc.main_frame_data.url
    text_items = extract_text_from_node(apc.root_node)
    return title, url, text_items


def process_fallback(title, url, inner_text):
    """Fallback: split innerText into text items by lines.

    Args:
        title: Page title, passed through unchanged.
        url: Page URL, passed through unchanged.
        inner_text: Raw ``innerText`` of the page; None is treated as empty
            text, since the script call can yield None for a body-less page.

    Returns:
        ``(title, url, lines)`` where ``lines`` holds the stripped, non-empty
        lines of ``inner_text``.
    """
    raw_lines = (inner_text or "").split("\n")
    lines = [line.strip() for line in raw_lines if line.strip()]
    return title, url, lines


# --- Full pipeline ---
def run_pipeline(title, url, text_items, sp, embedder, classifier):
    """Run the full embedding + classification pipeline.

    Returns:
        ``(score, passages)``: the classifier output as a float and the list
        of passages that were embedded.
    """
    # 1. Create passages
    passages = chunk_passages(text_items)

    # 2. Embed title + url
    title_url_text = f"{title} - {url}"
    title_url_emb = embed(embedder, tokenize(sp, title_url_text))  # [1, 768]

    # 3. Embed passages and mean-pool
    if passages:
        passage_embeddings = [
            embed(embedder, tokenize(sp, passage))[0] for passage in passages
        ]
        # Mean pooling
        mean_pooled = np.mean(passage_embeddings, axis=0, keepdims=True)  # [1, 768]
    else:
        # No text at all: feed a zero vector for the passage half.
        mean_pooled = np.zeros((1, EMBEDDING_DIM), dtype=np.float32)

    # 4. Concatenate: [title_url(768) | passages_mean(768)] = [1, 1536]
    input_vector = np.concatenate(
        [title_url_emb, mean_pooled], axis=1
    ).astype(np.float32)

    # 5. Classify
    score = classify(classifier, input_vector)
    return score, passages


# --- Streamlit UI ---
st.set_page_config(page_title="Shopping Classifier", layout="wide")
st.html(""" """)
st.subheader("Shopping Page Classifier")
#st.caption("Using Chrome's OPTIMIZATION_TARGET_SHOPPING_CLASSIFIER model")

url = st.text_input("Enter URL", placeholder="https://www.amazon.com/dp/B0...")

if st.button("Classify", type="primary") and url:
    sp = load_sp()
    embedder = load_embedder()
    classifier = load_classifier()

    with st.spinner("Loading page in Chrome Canary..."):
        apc_data, fallback_title, page_url, inner_text = fetch_page_content(url)

    # Process page content: prefer the proto, fall back to innerText when the
    # proto is missing or cannot be parsed.
    used_method = None
    if apc_data:
        try:
            title, resolved_url, text_items = process_apc(apc_data)
            used_method = "CDP AnnotatedPageContent"
        except Exception as e:
            st.warning(f"Proto parse failed: {e}, falling back to innerText")

    if used_method is None:
        title, resolved_url, text_items = process_fallback(
            fallback_title, page_url, inner_text
        )
        used_method = "innerText fallback"

    with st.spinner("Running inference..."):
        score, passages = run_pipeline(
            title, resolved_url, text_items, sp, embedder, classifier
        )

    # --- Results ---
    threshold = 0.5
    is_shopping = score >= threshold

    col1, col2 = st.columns(2)
    with col1:
        st.metric("Score", f"{score:.4f}")
    with col2:
        if is_shopping:
            st.success(f"SHOPPING PAGE (>= {threshold})")
        else:
            st.info(f"NOT SHOPPING (< {threshold})")

    # Details
    with st.expander("Details"):
        st.write(f"**Method:** {used_method}")
        st.write(f"**Title:** {title}")
        st.write(f"**URL:** {resolved_url}")
        st.write(f"**Text items extracted:** {len(text_items)}")
        st.write(f"**Passages created:** {len(passages)}")
        passages_json = {f"passage_{i+1}": p for i, p in enumerate(passages)}
        st.json(passages_json)