Spaces:
Sleeping
Sleeping
import mimetypes
import os
import re
from pathlib import Path

import pandas as pd
import snowflake.connector
import streamlit as st
from cryptography.hazmat.primitives import serialization

from gpt import CustomGPT
from headshot_scraper import download_author_image_for_site
| # ------------------------------ | |
| # Helper to fetch env with HF prefix fallback | |
| # ------------------------------ | |
def get_env(name: str):
    """Resolve a secret from the environment, preferring the HF Space form.

    Hugging Face Spaces expose repo secrets as ``REPO_SECRET_<name>``; when
    that prefixed variable is unset or empty, fall back to the plain name.
    """
    for key in (f"REPO_SECRET_{name}", name):
        value = os.environ.get(key)
        if value:
            return value
    # Mirror the plain lookup's result (None or empty string) when
    # neither variable holds a truthy value.
    return os.environ.get(name)
| CATALOG_DATA_PATH = Path(__file__).with_name("data.csv") | |
def load_catalog_data():
    """Read data.csv (if present) into a DataFrame for the dropdown presets.

    Tries utf-8 first, then utf-8-sig and latin1 as decoding fallbacks.
    Returns None when the file is missing or unreadable; every failure is
    reported through the Streamlit UI instead of raising.
    """
    if not CATALOG_DATA_PATH.exists():
        st.info("Upload data.csv to populate dropdown options. Using defaults instead.")
        return None

    decode_error = None
    for attempt, encoding in enumerate(("utf-8", "utf-8-sig", "latin1")):
        try:
            if attempt:
                # Only the fallback encodings get a visible notice.
                st.info(
                    f"Reading catalog data with {encoding} encoding fallback.",
                )
            return pd.read_csv(CATALOG_DATA_PATH, encoding=encoding)
        except UnicodeDecodeError as exc:
            decode_error = exc
        except Exception as exc:  # pragma: no cover - UI surfaced warning only
            st.warning(f"⚠️ Could not load catalog data from data.csv: {exc}")
            return None

    st.warning(
        "⚠️ Could not load catalog data from data.csv due to encoding issues. "
        f"Last error: {decode_error}"
    )
    return None
def collect_unique_options(df, candidate_columns, split_chars=None):
    """Return sorted unique values from the first productive candidate column.

    Walks `candidate_columns` in order; the first column present in *df*
    whose cells yield at least one non-empty value wins. When `split_chars`
    (a regex) is given, string cells are split on it before deduplication.
    Returns [] when *df* is None or no candidate produces anything.
    """
    if df is None:
        return []

    for column in candidate_columns:
        if column not in df.columns:
            continue
        unique_values = set()
        for cell in df[column].dropna():
            if split_chars and isinstance(cell, str):
                pieces = (piece.strip() for piece in re.split(split_chars, cell))
                unique_values.update(piece for piece in pieces if piece)
            else:
                cleaned = str(cell).strip()
                if cleaned:
                    unique_values.add(cleaned)
        ordered = sorted(unique_values)
        if ordered:
            return ordered
    return []
| # ------------------------------ | |
| # Snowflake connection | |
| # ------------------------------ | |
def connect_to_snowflake():
    """Open a key-pair-authenticated Snowflake connection, or None on failure.

    Credentials are pulled from environment secrets via get_env(); every
    failure path surfaces a message in the Streamlit UI instead of raising.
    """
    pem = get_env("snowflake_private_key")
    if pem is None:
        st.warning("⚠️ Missing Snowflake private key. Add it as a HF Secret.")
        return None

    # The secret stores the PEM text; the connector needs the parsed key.
    try:
        key = serialization.load_pem_private_key(pem.encode(), password=None)
    except Exception as e:
        st.error(f"❌ Could not load Snowflake private key: {e}")
        return None

    try:
        return snowflake.connector.connect(
            user=get_env("snowflake_user"),
            account=get_env("snowflake_account_identifier"),
            private_key=key,
            role=get_env("snowflake_role"),
            warehouse=get_env("snowflake_warehouse"),
            database=get_env("snowflake_database"),
            schema=get_env("snowflake_schema"),
        )
    except Exception as e:
        st.error(f"❌ Snowflake connection failed: {e}")
        return None
def fetch_sites(conn):
    """Fetch distinct (site_name, url) pairs from SITE_EXTENDED.

    Returns a list of {"site_name": ..., "url": ...} dicts ordered by site
    name, or [] when the query fails (the error is shown in the UI).
    """
    query = """
        SELECT DISTINCT
            site_name,
            url  -- Replace with actual URL column if different
        FROM analytics.adthrive.SITE_EXTENDED
        WHERE site_name IS NOT NULL
          AND url IS NOT NULL
        ORDER BY site_name
        """
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        return [
            {"site_name": site_name, "url": url}
            for site_name, url in cursor.fetchall()
        ]
    except Exception as e:
        st.error(f"Failed to fetch site list: {e}")
        return []
# ------------------------------
# Streamlit UI setup
# ------------------------------
st.set_page_config(page_title="Headshot Scraper", page_icon="🧑🍳", layout="wide")
st.title("Headshot / Author Image Scraper")
st.write(
    "Select a site from Snowflake (by name) or enter one manually. "
    "The scraper will use the stored URL to find the About page and extract the headshot."
)
# Initialize session state for last_result (so results persist across reruns)
if "last_result" not in st.session_state:
    st.session_state["last_result"] = None
# Chat transcript with the custom GPT; also survives Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
# ------------------------------
# Snowflake: connect + dropdown
# ------------------------------
st.write("🔑 Connecting to Snowflake…")
conn = connect_to_snowflake()

# Defaults used when Snowflake is unavailable or no site is selected yet.
sites = []
selected_site_name = ""
selected_site_url = ""

if conn:
    st.success(f"Connected to Snowflake as {get_env('snowflake_user')}")
    sites = fetch_sites(conn)
    # Leading "" entry lets the selectbox start with nothing chosen.
    site_name_options = [""] + [s["site_name"] for s in sites]
    selected_site_name = st.selectbox("Select site by name:", site_name_options)
    if selected_site_name:
        # Resolve the stored URL for the chosen site name.
        match = next((s for s in sites if s["site_name"] == selected_site_name), None)
        if match:
            selected_site_url = match["url"]
            st.caption(f"URL from Snowflake: {selected_site_url}")
        else:
            st.warning("No URL found for the selected site.")
else:
    st.warning("Snowflake connection not available. Manual entry only.")
# ------------------------------
# Manual URL entry fallback
# ------------------------------
manual_entry = st.text_input(
    "Or enter a site manually:",
    placeholder="damndelicious.net",
)

# Final URL to be used (Snowflake URL takes precedence)
site_or_url = selected_site_url if selected_site_url else manual_entry
# ------------------------------
# Scrape button (updates session_state)
# ------------------------------
if st.button("Scrape headshot"):
    if not site_or_url.strip():
        st.error("Please select or enter a site.")
    else:
        with st.spinner("Scraping…"):
            try:
                # Scraper locates the About page and saves the headshot locally.
                result = download_author_image_for_site(
                    site_or_url, out_dir="/tmp/author_images"
                )
                # Store result so it persists across reruns
                st.session_state["last_result"] = result
            except Exception as e:
                st.error(f"Scrape failed: {e}")
                # Clear any stale result from a previous successful run.
                st.session_state["last_result"] = None
# ------------------------------
# Display last result (persistent across reruns)
# ------------------------------
result = st.session_state.get("last_result")

if result:
    st.subheader("Result")
    st.write(f"**Base site:** {result['site_base_url']}")
    st.write(f"**About URL:** {result['about_url']}")
    st.write(f"**Page title:** {result['title']}")
    st.write(f"**Headshot URL:** {result['author_image_url']}")
    st.write(f"**Saved file:** {result['local_path']}")

    local_path = result.get("local_path")
    if local_path:
        st.image(local_path, caption="Detected headshot", width=350)

        # Download button – this will trigger a rerun,
        # but the result is preserved in st.session_state
        try:
            with open(local_path, "rb") as f:
                img_bytes = f.read()
            # Derive the MIME type from the file name rather than assuming
            # JPEG: scraped headshots can also be PNG, WebP, GIF, etc.
            mime_type, _ = mimetypes.guess_type(local_path)
            st.download_button(
                "⬇️ Download Image",
                data=img_bytes,
                file_name=os.path.basename(local_path),
                mime=mime_type or "image/jpeg",
            )
        except Exception as e:
            st.warning(f"Could not prepare download: {e}")
    else:
        st.warning("No headshot found for this site.")
# ------------------------------
# Catalog dropdown presets for GPT filters
# ------------------------------
catalog_df = load_catalog_data()

# Each option list probes several likely column names so the CSV schema can
# vary; hand-picked defaults fill in when the CSV or column is missing.
country_options = collect_unique_options(
    catalog_df,
    ["country", "Country", "region", "Region"],
)
# Ensure the default campaign country is always selectable (and listed first).
if "United States" not in country_options:
    country_options = ["United States"] + country_options

vertical_options = collect_unique_options(
    catalog_df,
    ["vertical", "Vertical", "primary_vertical", "PrimaryVertical"],
)

demographic_options = collect_unique_options(
    catalog_df,
    [
        "demographic",
        "Demographic",
        "audience_demographic",
        "AudienceDemographic",
        "audience_region",
        "AudienceRegion",
        "gender",
        "Gender",
    ],
    split_chars=r"[;,]",  # cells may hold multiple values separated by ; or ,
)

format_options = collect_unique_options(
    catalog_df,
    ["format", "Format", "formats", "Formats", "formats_supported"],
    split_chars=r"[;,/]",
)
if not format_options:
    format_options = ["IG reel", "Story", "Article", "Video"]

platform_options = collect_unique_options(
    catalog_df,
    ["platform", "Platform", "platforms", "Platforms", "platforms_supported"],
    split_chars=r"[;,/]",
)
# Always offer the two core platforms even when absent from the CSV.
platform_defaults = ["Instagram", "TikTok"]
for default_platform in platform_defaults:
    if default_platform not in platform_options:
        platform_options.append(default_platform)
platform_options = sorted(set(platform_options))

follower_tier_options = collect_unique_options(
    catalog_df,
    ["follower_tier", "FollowerTier", "tier", "Tier", "audience_tier"],
    split_chars=r"[;,]",
)
if not follower_tier_options:
    follower_tier_options = ["Nano", "Micro", "Mid", "Macro", "Mega"]
def summarize_filters(filters):
    """Render the campaign filter selections as a text block for the GPT prompt."""
    formats_text = (
        ", ".join(filters["formats"]) if filters["formats"] else "Not specified"
    )
    summary = [
        "Mandatory filters (fail any = exclude):",
        f"- Country: {filters['country']}",
        f"- Has IG account required: {filters['has_ig_account']}",
        f"- Interested in custom content: {filters['interested_in_custom_content']}",
        f"- Allow potential advertiser concern flag: {filters['allow_advertiser_concern']}",
        f"- Brand avoidance list must not include: {filters['brand_avoidance_brand'] or 'N/A'}",
        "User-selected campaign criteria:",
        f"- Vertical: {filters['vertical'] or 'Not specified'}",
        f"- Demographic: {filters['demographic'] or 'Not specified'}",
        f"- Required formats: {formats_text}",
        f"- Platform: {filters['platform']}",
        f"- Follower tier target: {filters['follower_tier'] or 'Not specified (use default tiers)'}",
        f"- Prioritize Creator Collaborative opt-in: {filters['prioritize_creator_collab']}",
    ]
    return "\n".join(summary)
st.divider()
st.header("Creator Catalog GPT")
st.caption(
    "Chat with the custom GPT using your OpenAI credentials. "
    "Set REPO_SECRET_OPENAI_API_KEY (and optional OPENAI_BASE_URL, CUSTOM_GPT_MODEL, "
    "CUSTOM_GPT_INSTRUCTIONS) as secrets in the Hugging Face Space."
)

st.subheader("Campaign filters")
st.caption(
    "Standardize the inputs sent to the GPT using dropdowns populated from data.csv when available."
)

# Two-column layout: mandatory filters left, campaign criteria right.
col1, col2 = st.columns(2)
with col1:
    selected_country = st.selectbox("Country", country_options, index=0)
    has_ig_account = st.checkbox("Require Instagram account", value=True)
    interested_custom = st.checkbox("Interested in custom content", value=True)
    allow_advertiser_concern = st.checkbox(
        "Allow creators with advertiser concern flag", value=False
    )
    brand_avoidance = st.text_input(
        "Brand to avoid (will exclude creators flagged with this brand)",
        placeholder="Campaign brand name",
    )
with col2:
    # "(Not specified)" sentinel keeps each selectbox valid when the CSV
    # supplied no options for that dimension.
    vertical = st.selectbox(
        "Vertical",
        (
            ["(Not specified)"] + vertical_options
            if vertical_options
            else ["(Not specified)"]
        ),
    )
    demographic = st.selectbox(
        "Demographic focus",
        (
            ["(Not specified)"] + demographic_options
            if demographic_options
            else ["(Not specified)"]
        ),
    )
    format_selection = st.multiselect("Required formats", format_options)
    # Default the platform picker to Instagram when available.
    platform_default_index = (
        platform_options.index("Instagram") if "Instagram" in platform_options else 0
    )
    platform = st.selectbox("Platform", platform_options, index=platform_default_index)
    follower_tier = st.selectbox(
        "Follower tier match (returns requested tier or one below)",
        ["(Not specified)"] + follower_tier_options,
    )
    prioritize_creator_collab = st.checkbox(
        "Prioritize Creator Collaborative opt-in", value=True
    )

# Normalized filter payload for summarize_filters(); the "(Not specified)"
# sentinel is mapped back to an empty string.
campaign_filters = {
    "country": selected_country,
    "has_ig_account": has_ig_account,
    "interested_in_custom_content": interested_custom,
    "allow_advertiser_concern": allow_advertiser_concern,
    "brand_avoidance_brand": brand_avoidance.strip(),
    "vertical": "" if vertical == "(Not specified)" else vertical,
    "demographic": "" if demographic == "(Not specified)" else demographic,
    "formats": format_selection,
    "platform": platform,
    "follower_tier": "" if follower_tier == "(Not specified)" else follower_tier,
    "prioritize_creator_collab": prioritize_creator_collab,
}

st.markdown("**Filter summary for GPT:**")
st.code(summarize_filters(campaign_filters))
prompt = st.text_area(
    "Ask the GPT a question",
    key="gpt_prompt",
    placeholder="E.g., summarize the most recent scraping result",
)

if st.button("Send to GPT"):
    if not prompt.strip():
        st.error("Please enter a question or prompt for the GPT.")
    else:
        try:
            client = CustomGPT()
            # Append the standardized filter summary so the GPT applies the
            # on-screen campaign criteria consistently.
            filter_summary = summarize_filters(campaign_filters)
            full_prompt = (
                f"{prompt.strip()}\n\n"
                "Use these campaign filter selections when applying the Creator Catalog instructions:\n"
                f"{filter_summary}\n"
            )
            reply = client.run(full_prompt, history=st.session_state["chat_history"])
            # Record both sides of the exchange so the transcript persists
            # across reruns and feeds future GPT calls as history.
            st.session_state["chat_history"].extend(
                [
                    {"role": "user", "content": full_prompt},
                    {"role": "assistant", "content": reply},
                ]
            )
        except Exception as e:
            st.error(f"GPT request failed: {e}")
# Render the accumulated conversation, oldest message first.
transcript = st.session_state["chat_history"]
if transcript:
    st.subheader("Conversation")
    for entry in transcript:
        speaker = "You" if entry["role"] == "user" else "GPT"
        st.markdown(f"**{speaker}:** {entry['content']}")