# chrome_models / 74 / app.py
# Uploaded by dejanseo ("Upload app.py", commit 8e2aea6, verified)
import streamlit as st
import base64
import time
import numpy as np
import sentencepiece as spm
from ai_edge_litert.interpreter import Interpreter
from selenium import webdriver
from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.chrome.options import Options as ChromeOptions
import common_quality_data_pb2 as apc_pb2
import os
# --- Paths ---
# Directory containing this script; the model assets below are resolved
# relative to it rather than to the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# TFLite model loaded by load_embedder().
EMBEDDER_PATH = os.path.join(BASE_DIR, "passage_embedder", "model.tflite")
# TFLite model loaded by load_classifier().
CLASSIFIER_PATH = os.path.join(BASE_DIR, "shopping_classifier", "model.tflite")
# SentencePiece model loaded by load_sp().
SPM_PATH = os.path.join(BASE_DIR, "passage_embedder", "sentencepiece.model")
# Chrome Canary binary handed to Selenium. The %LOCALAPPDATA% placeholder is
# Windows-style; os.path.expandvars substitutes it on Windows.
CHROME_CANARY = os.path.expandvars(
r"%LOCALAPPDATA%\Google\Chrome SxS\Application\chrome.exe"
)
# --- Pipeline constants ---
# Length every token sequence is truncated/padded to in tokenize().
INPUT_WINDOW_SIZE = 64
# Width of the zero vector run_pipeline() substitutes when there are no passages.
EMBEDDING_DIM = 768
# Default limits for chunk_passages(): word budget per passage, the word count
# below which an item is always merged into the current passage, and the
# maximum number of passages returned.
MAX_WORDS_PER_PASSAGE = 100
MIN_WORDS_PER_PASSAGE = 5
MAX_PASSAGES = 10
# --- Load models once ---
@st.cache_resource
def load_sp():
    """Load the SentencePiece tokenizer stored at ``SPM_PATH``.

    Decorated with ``st.cache_resource`` so the processor is built once and
    reused instead of being reloaded on every Streamlit rerun.

    Returns:
        A ``SentencePieceProcessor`` with the model loaded.
    """
    sp = spm.SentencePieceProcessor()
    sp.Load(SPM_PATH)
    return sp
@st.cache_resource
def load_embedder():
    """Create the passage-embedder interpreter from ``EMBEDDER_PATH``.

    Decorated with ``st.cache_resource`` so the interpreter is built once and
    reused instead of being reloaded on every Streamlit rerun.

    Returns:
        An ``Interpreter`` whose tensors have already been allocated.
    """
    interp = Interpreter(model_path=EMBEDDER_PATH)
    # Tensors must be allocated before set_tensor()/invoke() can be used.
    interp.allocate_tensors()
    return interp
@st.cache_resource
def load_classifier():
    """Create the shopping-classifier interpreter from ``CLASSIFIER_PATH``.

    Decorated with ``st.cache_resource`` so the interpreter is built once and
    reused instead of being reloaded on every Streamlit rerun.

    Returns:
        An ``Interpreter`` whose tensors have already been allocated.
    """
    interp = Interpreter(model_path=CLASSIFIER_PATH)
    # Tensors must be allocated before set_tensor()/invoke() can be used.
    interp.allocate_tensors()
    return interp
# --- Text extraction from AnnotatedPageContent proto ---
def extract_text_from_node(node):
    """Recursively extract text items from a ContentNode tree.

    Each node contributes at most one string, taken from the first of
    ``text_data.text_content``, ``table_data.table_name`` or
    ``image_data.image_caption`` whose parent field is set on the node's
    ``content_attributes``. The string is stripped and dropped when empty.
    Children are visited in order after the node itself (pre-order).

    Args:
        node: Object exposing ``content_attributes`` (supporting
            ``HasField``) and an iterable ``children_nodes``.

    Returns:
        List of non-empty, whitespace-stripped strings.
    """
    items = []
    attrs = node.content_attributes

    # Only the first populated field is consulted; later ones are ignored
    # even if the chosen text turns out to be blank.
    if attrs.HasField("text_data"):
        text = attrs.text_data.text_content.strip()
        if text:
            items.append(text)
    elif attrs.HasField("table_data"):
        text = attrs.table_data.table_name.strip()
        if text:
            items.append(text)
    elif attrs.HasField("image_data"):
        text = attrs.image_data.image_caption.strip()
        if text:
            items.append(text)

    for child in node.children_nodes:
        items.extend(extract_text_from_node(child))
    return items
def chunk_passages(text_items, max_words=MAX_WORDS_PER_PASSAGE,
                   min_words=MIN_WORDS_PER_PASSAGE, max_passages=MAX_PASSAGES):
    """Greedy word-count chunking matching Chrome's algorithm.

    Items are consumed in order and joined with single spaces into passages:

    * An item with fewer than ``min_words`` words is always merged into the
      passage being built.
    * A longer item that would push the passage past ``max_words`` closes the
      current passage and starts a new one; otherwise it is merged.
    * A passage is closed as soon as it holds ``max_words`` words or more.

    Items are never split, so a passage can exceed ``max_words``.

    Args:
        text_items: Iterable of strings.
        max_words: Word budget per passage.
        min_words: Items shorter than this are always merged.
        max_passages: Maximum number of passages returned.

    Returns:
        List of at most ``max_passages`` passage strings.
    """
    passages = []
    current = []
    current_word_count = 0

    for item in text_items:
        item_word_count = len(item.split())

        if item_word_count < min_words:
            current.append(item)
            current_word_count += item_word_count
        elif current_word_count + item_word_count > max_words and current:
            # Item does not fit: close the passage and start a new one with it.
            passages.append(" ".join(current))
            current = [item]
            current_word_count = item_word_count
        else:
            current.append(item)
            current_word_count += item_word_count

        if current_word_count >= max_words:
            passages.append(" ".join(current))
            current = []
            current_word_count = 0

        if len(passages) >= max_passages:
            break

    # Flush the trailing partial passage if there is still room for it.
    if current and len(passages) < max_passages:
        passages.append(" ".join(current))

    # One iteration can close two passages, so trim any overshoot.
    return passages[:max_passages]
# --- Tokenization ---
def tokenize(sp, text):
    """SentencePiece encode, append EOS if room, resize to INPUT_WINDOW_SIZE.

    Args:
        sp: Tokenizer exposing ``Encode(text)`` and ``eos_id()``.
        text: String to encode.

    Returns:
        ``np.int32`` array of shape ``(1, INPUT_WINDOW_SIZE)``; sequences
        longer than the window are truncated, shorter ones are zero-padded.
    """
    token_ids = sp.Encode(text)

    # EOS is only added when the encoded text leaves at least one free slot.
    if len(token_ids) < INPUT_WINDOW_SIZE:
        token_ids.append(sp.eos_id())

    token_ids = token_ids[:INPUT_WINDOW_SIZE]
    # Zero-pad
    token_ids += [0] * (INPUT_WINDOW_SIZE - len(token_ids))
    return np.array(token_ids, dtype=np.int32).reshape(1, INPUT_WINDOW_SIZE)
# --- Embedding ---
def embed(interp, token_ids):
    """Run passage embedder: int32[1,64] -> float32[1,768].

    Args:
        interp: Interpreter exposing ``get_input_details``,
            ``get_output_details``, ``set_tensor``, ``invoke`` and
            ``get_tensor``.
        token_ids: Array written to the model's first input tensor.

    Returns:
        A copy of the model's first output tensor.
    """
    input_details = interp.get_input_details()
    output_details = interp.get_output_details()

    interp.set_tensor(input_details[0]["index"], token_ids)
    interp.invoke()

    # Copy so the caller's array is independent of what get_tensor returned.
    return interp.get_tensor(output_details[0]["index"]).copy()
# --- Classification ---
def classify(interp, input_vector):
    """Run shopping classifier: float32[1,1536] -> float32[1,1].

    Args:
        interp: Interpreter exposing ``get_input_details``,
            ``get_output_details``, ``set_tensor``, ``invoke`` and
            ``get_tensor``.
        input_vector: Array written to the model's first input tensor.

    Returns:
        Element ``[0][0]`` of the model's first output tensor as a Python
        ``float``.
    """
    input_details = interp.get_input_details()
    output_details = interp.get_output_details()

    interp.set_tensor(input_details[0]["index"], input_vector)
    interp.invoke()

    return float(interp.get_tensor(output_details[0]["index"])[0][0])
# --- CDP page extraction ---
def fetch_page_content(url):
    """Use Chrome Canary + Selenium CDP to get AnnotatedPageContent.

    Launches headless Chrome from ``CHROME_CANARY``, loads ``url``, waits five
    seconds, then requests the page's AnnotatedPageContent over CDP. The page
    title, visible text and final URL are always collected as well so callers
    can fall back to them.

    Args:
        url: Address to load.

    Returns:
        Tuple ``(apc_data, title, page_url, inner_text)``. ``apc_data`` is the
        base64-decoded CDP payload, or ``None`` when the CDP call failed (a
        Streamlit warning is shown in that case). ``inner_text`` is ``""``
        when the document has no ``<body>``.

    Raises:
        Any Selenium exception raised while starting the browser or loading
        the page; the browser is shut down before it propagates.
    """
    options = ChromeOptions()
    options.binary_location = CHROME_CANARY
    options.add_argument("--headless=new")
    options.add_argument("--disable-gpu")
    options.add_argument("--no-sandbox")

    driver = webdriver.Chrome(options=options)
    try:
        driver.get(url)
        # Wait for content to settle (Chrome uses 5s delay)
        time.sleep(5)

        # Try AnnotatedPageContent via CDP
        apc_data = None
        try:
            result = driver.execute_cdp_cmd(
                "Page.getAnnotatedPageContent",
                {"includeActionableInformation": True},
            )
            apc_data = base64.b64decode(result["content"])
        except Exception as e:
            # Best effort: the caller falls back to innerText when this is None.
            st.warning(f"CDP AnnotatedPageContent failed: {e}")

        # Fallback: get title and innerText
        title = driver.title
        # document.body is null for documents without a <body> (e.g. raw
        # XML); guard so the fallback yields empty text instead of raising.
        inner_text = driver.execute_script(
            "return document.body ? document.body.innerText : '';"
        )
        page_url = driver.current_url
    finally:
        driver.quit()

    return apc_data, title, page_url, inner_text
def process_apc(apc_data):
    """Parse AnnotatedPageContent proto and extract title, url, text items.

    Args:
        apc_data: Serialized ``AnnotatedPageContent`` bytes.

    Returns:
        Tuple ``(title, url, text_items)`` where title and url come from
        ``main_frame_data`` and ``text_items`` is the list produced by
        ``extract_text_from_node`` for ``root_node``.
    """
    apc = apc_pb2.AnnotatedPageContent()
    apc.ParseFromString(apc_data)

    title = apc.main_frame_data.title
    url = apc.main_frame_data.url
    text_items = extract_text_from_node(apc.root_node)
    return title, url, text_items
def process_fallback(title, url, inner_text):
    """Fallback: split innerText into text items by lines.

    Args:
        title: Page title, returned unchanged.
        url: Page URL, returned unchanged.
        inner_text: Page text; ``None`` is treated as empty text.

    Returns:
        Tuple ``(title, url, lines)`` where ``lines`` holds every non-blank
        line of ``inner_text`` with surrounding whitespace removed.
    """
    # Strip each line once, then drop the ones that end up empty.
    stripped = (line.strip() for line in (inner_text or "").split("\n"))
    lines = [line for line in stripped if line]
    return title, url, lines
# --- Full pipeline ---
def run_pipeline(title, url, text_items, sp, embedder, classifier):
    """Run the full embedding + classification pipeline.

    Args:
        title: Page title.
        url: Page URL.
        text_items: Strings to be chunked into passages.
        sp: Tokenizer passed to ``tokenize``.
        embedder: Interpreter passed to ``embed``.
        classifier: Interpreter passed to ``classify``.

    Returns:
        Tuple ``(score, passages)``: the classifier output and the passages
        that were embedded.
    """
    # 1. Create passages
    passages = chunk_passages(text_items)

    # 2. Embed title + url
    title_url_text = f"{title} - {url}"
    title_url_tokens = tokenize(sp, title_url_text)
    title_url_emb = embed(embedder, title_url_tokens)  # [1, 768]

    # 3. Embed passages and mean-pool
    if passages:
        passage_embeddings = []
        for passage in passages:
            tokens = tokenize(sp, passage)
            emb = embed(embedder, tokens)
            passage_embeddings.append(emb[0])
        # Mean pooling
        mean_pooled = np.mean(passage_embeddings, axis=0, keepdims=True)  # [1, 768]
    else:
        # No text at all: feed the classifier a zero passage vector.
        mean_pooled = np.zeros((1, EMBEDDING_DIM), dtype=np.float32)

    # 4. Concatenate: [title_url(768) | passages_mean(768)] = [1, 1536]
    input_vector = np.concatenate([title_url_emb, mean_pooled], axis=1).astype(np.float32)

    # 5. Classify
    score = classify(classifier, input_vector)
    return score, passages
# --- Streamlit UI ---
st.set_page_config(page_title="Shopping Classifier", layout="wide")

# Recolor primary buttons green.
st.html("""
<style>
.stButton > button[kind="primary"] {
background-color: #2e7d32;
border-color: #2e7d32;
}
.stButton > button[kind="primary"]:hover {
background-color: #1b5e20;
border-color: #1b5e20;
}
</style>
""")

st.subheader("Shopping Page Classifier")
#st.caption("Using Chrome's OPTIMIZATION_TARGET_SHOPPING_CLASSIFIER model")
url = st.text_input("Enter URL", placeholder="https://www.amazon.com/dp/B0...")

# Nothing runs until the button is pressed with a non-empty URL.
if st.button("Classify", type="primary") and url:
    sp = load_sp()
    embedder = load_embedder()
    classifier = load_classifier()

    with st.spinner("Loading page in Chrome Canary..."):
        apc_data, fallback_title, page_url, inner_text = fetch_page_content(url)

    # Process page content
    used_method = None
    if apc_data:
        try:
            title, resolved_url, text_items = process_apc(apc_data)
            used_method = "CDP AnnotatedPageContent"
        except Exception as e:
            # An unparsable proto is not fatal: use the innerText data instead.
            st.warning(f"Proto parse failed: {e}, falling back to innerText")
            title, resolved_url, text_items = process_fallback(
                fallback_title, page_url, inner_text
            )
            used_method = "innerText fallback"
    else:
        title, resolved_url, text_items = process_fallback(
            fallback_title, page_url, inner_text
        )
        used_method = "innerText fallback"

    with st.spinner("Running inference..."):
        score, passages = run_pipeline(
            title, resolved_url, text_items, sp, embedder, classifier
        )

    # --- Results ---
    threshold = 0.5
    is_shopping = score >= threshold

    col1, col2 = st.columns(2)
    with col1:
        st.metric("Score", f"{score:.4f}")
    with col2:
        if is_shopping:
            st.success(f"SHOPPING PAGE (>= {threshold})")
        else:
            st.info(f"NOT SHOPPING (< {threshold})")

    # Details
    with st.expander("Details"):
        st.write(f"**Method:** {used_method}")
        st.write(f"**Title:** {title}")
        st.write(f"**URL:** {resolved_url}")
        st.write(f"**Text items extracted:** {len(text_items)}")
        st.write(f"**Passages created:** {len(passages)}")
        passages_json = {f"passage_{i+1}": p for i, p in enumerate(passages)}
        st.json(passages_json)