| import streamlit as st |
| import base64 |
| import time |
| import numpy as np |
| import sentencepiece as spm |
| from ai_edge_litert.interpreter import Interpreter |
| from selenium import webdriver |
| from selenium.webdriver.chrome.service import Service as ChromeService |
| from selenium.webdriver.chrome.options import Options as ChromeOptions |
| import common_quality_data_pb2 as apc_pb2 |
| import os |
|
|
| |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
EMBEDDER_PATH = os.path.join(BASE_DIR, "passage_embedder", "model.tflite")
CLASSIFIER_PATH = os.path.join(BASE_DIR, "shopping_classifier", "model.tflite")
SPM_PATH = os.path.join(BASE_DIR, "passage_embedder", "sentencepiece.model")
# Chrome Canary binary. The default is the per-user Windows install location;
# os.path.expandvars only expands %VAR% references on Windows, so on other
# platforms (or for a custom install) set the CHROME_CANARY_PATH env variable.
CHROME_CANARY = os.environ.get("CHROME_CANARY_PATH") or os.path.expandvars(
    r"%LOCALAPPDATA%\Google\Chrome SxS\Application\chrome.exe"
)

# Passage embedder I/O: each call takes INPUT_WINDOW_SIZE token ids and yields
# an EMBEDDING_DIM-wide vector.
INPUT_WINDOW_SIZE = 64
EMBEDDING_DIM = 768
# Limits used by chunk_passages when grouping text items into passages.
MAX_WORDS_PER_PASSAGE = 100
MIN_WORDS_PER_PASSAGE = 5
MAX_PASSAGES = 10
|
|
|
|
| |
@st.cache_resource
def load_sp():
    """Return the SentencePiece tokenizer read from SPM_PATH.

    Wrapped in st.cache_resource so script reruns reuse the loaded processor
    instead of reading the model file again.
    """
    processor = spm.SentencePieceProcessor()
    processor.Load(SPM_PATH)
    return processor
|
|
|
|
@st.cache_resource
def load_embedder():
    """Return a ready-to-run TFLite interpreter for the passage embedder.

    Tensors are allocated before returning, so callers can set inputs and
    invoke immediately. Cached with st.cache_resource, so reruns reuse it.
    NOTE(review): the cached interpreter is shared by every session; confirm
    this is acceptable if the app serves concurrent users.
    """
    embedder_interpreter = Interpreter(model_path=EMBEDDER_PATH)
    embedder_interpreter.allocate_tensors()
    return embedder_interpreter
|
|
|
|
@st.cache_resource
def load_classifier():
    """Return a ready-to-run TFLite interpreter for the shopping classifier.

    Tensors are allocated up front; st.cache_resource keeps the instance
    across reruns so the model is only loaded once.
    """
    classifier_interpreter = Interpreter(model_path=CLASSIFIER_PATH)
    classifier_interpreter.allocate_tensors()
    return classifier_interpreter
|
|
|
|
| |
def extract_text_from_node(node):
    """Collect the non-empty text items of a ContentNode tree in document order.

    Visits ``node`` and all descendants depth-first, pre-order: a node's own
    text comes before its children's, and children are visited in stored
    order. For each node, only the first present field among
    ``text_data.text_content``, ``table_data.table_name`` and
    ``image_data.image_caption`` (checked in that order) is read. It is kept
    if it is non-empty after stripping.

    An explicit stack replaces recursion, so very deeply nested pages cannot
    hit Python's recursion limit.

    Args:
        node: root ContentNode (e.g. ``AnnotatedPageContent.root_node``).

    Returns:
        list[str]: stripped, non-empty text items.
    """
    items = []
    stack = [node]
    while stack:
        current = stack.pop()
        attrs = current.content_attributes
        if attrs.HasField("text_data"):
            text = attrs.text_data.text_content
        elif attrs.HasField("table_data"):
            text = attrs.table_data.table_name
        elif attrs.HasField("image_data"):
            text = attrs.image_data.image_caption
        else:
            text = ""
        text = text.strip()
        if text:
            items.append(text)
        # Push children in reverse so the first child is popped next, which
        # reproduces the order of a recursive pre-order walk.
        stack.extend(reversed(list(current.children_nodes)))
    return items
|
|
|
|
def chunk_passages(text_items, max_words=MAX_WORDS_PER_PASSAGE,
                   min_words=MIN_WORDS_PER_PASSAGE, max_passages=MAX_PASSAGES):
    """Group text items into passages by greedy word counting.

    Intended to mirror Chrome's passage chunking. Rules:
      * an item with fewer than ``min_words`` words always joins the passage
        being built;
      * a longer item closes the current (non-empty) passage first if adding
        it would exceed ``max_words``;
      * a passage is closed as soon as it reaches ``max_words`` words;
      * at most ``max_passages`` passages are produced.
    Items inside a passage are joined with a single space. A passage may
    exceed ``max_words``: a single long item is never split.

    Returns:
        list[str]: the passages, in input order.
    """
    chunks = []
    pending = []
    pending_words = 0

    for piece in text_items:
        n_words = len(piece.split())

        needs_break = (
            n_words >= min_words
            and pending
            and pending_words + n_words > max_words
        )
        if needs_break:
            chunks.append(" ".join(pending))
            pending, pending_words = [], 0

        pending.append(piece)
        pending_words += n_words

        if pending_words >= max_words:
            chunks.append(" ".join(pending))
            pending, pending_words = [], 0

        if len(chunks) >= max_passages:
            break

    if pending and len(chunks) < max_passages:
        chunks.append(" ".join(pending))

    return chunks[:max_passages]
|
|
|
|
| |
def tokenize(sp, text):
    """Encode ``text`` into a fixed-size int32 token window of shape (1, INPUT_WINDOW_SIZE).

    EOS is appended only when the encoding is shorter than the window. The
    result is then truncated to the window and right-padded with id 0.
    """
    ids = sp.Encode(text)
    if len(ids) < INPUT_WINDOW_SIZE:
        ids = ids + [sp.eos_id()]
    window = ids[:INPUT_WINDOW_SIZE]
    padding = [0] * (INPUT_WINDOW_SIZE - len(window))
    return np.array(window + padding, dtype=np.int32).reshape((1, -1))
|
|
|
|
| |
def embed(interp, token_ids):
    """Run the passage embedder on one token window.

    Feeds ``token_ids`` (documented as int32[1,64]) into the interpreter's
    first input and returns a copy of its first output (float32[1,768]).
    The copy detaches the result from the interpreter's buffer, so later
    invocations cannot overwrite it.
    """
    in_index = interp.get_input_details()[0]["index"]
    out_index = interp.get_output_details()[0]["index"]
    interp.set_tensor(in_index, token_ids)
    interp.invoke()
    embedding = interp.get_tensor(out_index)
    return embedding.copy()
|
|
|
|
| |
def classify(interp, input_vector):
    """Score one feature vector with the shopping classifier.

    ``input_vector`` (documented as float32[1,1536]) goes to the first input.
    The single value at [0][0] of the first output is returned as a float.
    """
    in_index = interp.get_input_details()[0]["index"]
    out_index = interp.get_output_details()[0]["index"]
    interp.set_tensor(in_index, input_vector)
    interp.invoke()
    scores = interp.get_tensor(out_index)
    return float(scores[0][0])
|
|
|
|
| |
def fetch_page_content(url):
    """Load ``url`` in headless Chrome Canary and capture its content.

    Uses Selenium to navigate and the CDP command
    ``Page.getAnnotatedPageContent`` to obtain the structured page content.
    The driver is always shut down, even if navigation fails.

    Returns:
        tuple: ``(apc_data, title, page_url, inner_text)``:
            apc_data: base64-decoded bytes from the CDP result, or ``None``
                when the command failed (a Streamlit warning is shown).
            title: ``driver.title`` after navigation.
            page_url: ``driver.current_url`` after navigation (reflects
                redirects).
            inner_text: ``document.body.innerText``, or ``""`` when the
                document has no body.

    Raises:
        Any Selenium/WebDriver exception from starting Chrome or navigating.
    """
    options = ChromeOptions()
    options.binary_location = CHROME_CANARY
    options.add_argument("--headless=new")
    options.add_argument("--disable-gpu")
    options.add_argument("--no-sandbox")

    driver = webdriver.Chrome(options=options)
    try:
        driver.get(url)
        # Fixed delay so client-side rendering can populate the page before
        # the content snapshot is taken.
        time.sleep(5)

        apc_data = None
        try:
            result = driver.execute_cdp_cmd(
                "Page.getAnnotatedPageContent",
                {"includeActionableInformation": True},
            )
            apc_data = base64.b64decode(result["content"])
        except Exception as e:
            # Best effort: report and return apc_data=None instead of failing
            # the whole fetch.
            st.warning(f"CDP AnnotatedPageContent failed: {e}")

        title = driver.title
        # document.body can be null (e.g. some non-HTML documents); reading
        # .innerText on it would raise a JavaScript error and lose the data
        # gathered above, so guard it and normalize null to "".
        inner_text = driver.execute_script(
            "return document.body ? document.body.innerText : '';"
        ) or ""
        page_url = driver.current_url
    finally:
        driver.quit()

    return apc_data, title, page_url, inner_text
|
|
|
|
def process_apc(apc_data):
    """Decode serialized AnnotatedPageContent bytes.

    Returns:
        tuple: ``(title, url, text_items)``. Title and URL come from the
        proto's ``main_frame_data``, and text_items from walking
        ``root_node``.

    Raises:
        Whatever ``ParseFromString`` raises on malformed input.
    """
    page = apc_pb2.AnnotatedPageContent()
    page.ParseFromString(apc_data)
    frame = page.main_frame_data
    return frame.title, frame.url, extract_text_from_node(page.root_node)
|
|
|
|
def process_fallback(title, url, inner_text):
    """Build ``(title, url, text_items)`` from a page's innerText.

    Each ``"\\n"``-separated line becomes one text item after stripping
    surrounding whitespace, and blank lines are dropped. ``None`` or empty
    ``inner_text`` yields no items instead of raising.
    """
    if not inner_text:
        return title, url, []
    # Strip each line once, then keep only the non-empty results.
    stripped = (raw.strip() for raw in inner_text.split("\n"))
    lines = [line for line in stripped if line]
    return title, url, lines
|
|
|
|
| |
def run_pipeline(title, url, text_items, sp, embedder, classifier):
    """Embed a page and score it with the shopping classifier.

    Steps:
      1. chunk ``text_items`` into passages;
      2. embed the string ``"{title} - {url}"``;
      3. embed each passage and mean-pool the vectors, or use an all-zero
         (1, EMBEDDING_DIM) vector when there are no passages;
      4. concatenate (2) and (3) along axis 1 as float32 and classify.

    Returns:
        tuple: ``(score, passages)``.
    """
    passages = chunk_passages(text_items)

    header_embedding = embed(embedder, tokenize(sp, f"{title} - {url}"))

    if not passages:
        body_embedding = np.zeros((1, EMBEDDING_DIM), dtype=np.float32)
    else:
        rows = [embed(embedder, tokenize(sp, passage))[0] for passage in passages]
        body_embedding = np.mean(rows, axis=0, keepdims=True)

    features = np.concatenate(
        [header_embedding, body_embedding], axis=1
    ).astype(np.float32)
    return classify(classifier, features), passages
|
|
|
|
| |
st.set_page_config(page_title="Shopping Classifier", layout="wide")

# Recolor primary buttons (the "Classify" button) green.
st.html("""
<style>
.stButton > button[kind="primary"] {
    background-color: #2e7d32;
    border-color: #2e7d32;
}
.stButton > button[kind="primary"]:hover {
    background-color: #1b5e20;
    border-color: #1b5e20;
}
</style>
""")
st.subheader("Shopping Page Classifier")

url = st.text_input("Enter URL", placeholder="https://www.amazon.com/dp/B0...")

if st.button("Classify", type="primary"):
    url = url.strip()
    if not url:
        st.warning("Please enter a URL to classify.")
        st.stop()
    # WebDriver navigation rejects scheme-less input such as "amazon.com/dp/...".
    if "://" not in url:
        url = "https://" + url

    sp = load_sp()
    embedder = load_embedder()
    classifier = load_classifier()

    with st.spinner("Loading page in Chrome Canary..."):
        try:
            apc_data, fallback_title, page_url, inner_text = fetch_page_content(url)
        except Exception as e:
            # Top-level UI boundary: show browser/navigation failures in the
            # app instead of a raw traceback.
            st.error(f"Failed to load page: {e}")
            st.stop()

    # Prefer the structured AnnotatedPageContent. Fall back to innerText lines
    # when CDP returned nothing or the proto could not be parsed.
    parsed = None
    if apc_data:
        try:
            parsed = process_apc(apc_data)
        except Exception as e:
            st.warning(f"Proto parse failed: {e}, falling back to innerText")
    if parsed is not None:
        title, resolved_url, text_items = parsed
        used_method = "CDP AnnotatedPageContent"
    else:
        title, resolved_url, text_items = process_fallback(
            fallback_title, page_url, inner_text
        )
        used_method = "innerText fallback"

    with st.spinner("Running inference..."):
        score, passages = run_pipeline(
            title, resolved_url, text_items, sp, embedder, classifier
        )

    threshold = 0.5
    is_shopping = score >= threshold
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Score", f"{score:.4f}")
    with col2:
        if is_shopping:
            st.success(f"SHOPPING PAGE (>= {threshold})")
        else:
            st.info(f"NOT SHOPPING (< {threshold})")

    with st.expander("Details"):
        st.write(f"**Method:** {used_method}")
        st.write(f"**Title:** {title}")
        st.write(f"**URL:** {resolved_url}")
        st.write(f"**Text items extracted:** {len(text_items)}")
        st.write(f"**Passages created:** {len(passages)}")
        passages_json = {f"passage_{i+1}": p for i, p in enumerate(passages)}
        st.json(passages_json)
|
|