"""Gradio demo for distill-structure model.""" import json import re import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer # --------------------------------------------------------------------------- # Model # --------------------------------------------------------------------------- MODEL_ID = "nahidstaq/distill-structure" SYSTEM = ( "You are an HTML structure analyzer. Given a compact DOM representation " "of a web page (with headings removed), identify the logical sections. " "Output a JSON array of sections, each with title, start_text, content_type, and assets fields." ) _model = None _tokenizer = None def _load(): global _model, _tokenizer if _model is None: device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 if device == "cuda" else torch.float32 _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) _model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=dtype, device_map="auto" ) _model.eval() return _model, _tokenizer # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _compact_dom(html: str) -> str: from lxml import html as lxml_html try: doc = lxml_html.fromstring(html) except Exception: return html[:3000] for tag in ("h1", "h2", "h3", "h4", "h5", "h6", "script", "style", "head"): for el in doc.findall(f".//{tag}"): p = el.getparent() if p is not None: p.remove(el) def _walk(el, depth=0): if not hasattr(el, "tag") or not isinstance(el.tag, str): return "" tag = el.tag indent = " " * depth if tag == "img": alt = el.get("alt", "") return f'{indent}{alt}' if alt else f'{indent}' if tag == "a": text = (el.text_content() or "").strip()[:40] href = (el.get("href") or "")[:60] return f'{indent} {text}' if tag in ("td", "th"): # Recurse into td if it has block children, otherwise truncate children = [c for c in el if hasattr(c, "tag") and isinstance(c.tag, str)] if children and depth < 8: lines = [f"{indent}<{tag}>"] for child in children: r = _walk(child, depth + 1) if r: lines.append(r) return "\n".join(lines) text = (el.text_content() or "").strip()[:60] return f"{indent}<{tag}> {text}" if text else "" if depth > 7: text = (el.text_content() or "").strip()[:80] return f"{indent}[... {text}...]" if text else "" text = (el.text or "").strip()[:50] attrs = "" for a in ("id", "class", "role"): v = el.get(a) if v: attrs += f' {a}="{v[:30]}"' line = f"{indent}<{tag}{attrs}>" if text: line += f" {text}" lines = [line] for child in el: r = _walk(child, depth + 1) if r: lines.append(r) return "\n".join(lines) body = doc.find(".//body") or doc result = _walk(body) # Truncate to 4096 chars if len(result) > 4096: result = result[:4096] + "\n... (truncated)" return result def _extract_title(html: str) -> str: m = re.search(r"(.*?)", html, re.I | re.S) return m.group(1).strip() if m else "Untitled" def _parse(raw: str) -> list[dict]: try: data = json.loads(raw) if isinstance(data, list): return data except json.JSONDecodeError: pass m = re.search(r"\[.*?\]", raw, re.S) if m: try: return json.loads(m.group()) except json.JSONDecodeError: pass return [] # --------------------------------------------------------------------------- # Inference # --------------------------------------------------------------------------- def analyze_html(html: str, page_title: str) -> tuple[str, str]: if not html.strip(): return "Please paste some HTML.", "" model, tokenizer = _load() compact = _compact_dom(html) title = page_title.strip() or _extract_title(html) messages = [ {"role": "system", "content": SYSTEM}, {"role": "user", "content": f"Page: {title}\n\n{compact}"}, ] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): ids = model.generate( **inputs, max_new_tokens=512, do_sample=False, temperature=None, top_p=None, pad_token_id=tokenizer.eos_token_id, ) raw = tokenizer.decode(ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip() sections = _parse(raw) # Pretty format sections as markdown table if sections: md = "| # | Type | Title | Start text |\n|---|---|---|---|\n" for i, s in enumerate(sections, 1): title_s = s.get("title", "") ctype = s.get("content_type", "?") start = (s.get("start_text") or "")[:50] md += f"| {i} | `{ctype}` | {title_s} | {start} |\n" else: md = "_Could not parse sections from model output._" return md, raw def analyze_url(url: str) -> tuple[str, str, str]: if not url.strip(): return "", "Please enter a URL.", "" try: import httpx r = httpx.get(url, follow_redirects=True, timeout=10, headers={"User-Agent": "Mozilla/5.0"}) html = r.text title = _extract_title(html) md, raw = analyze_html(html, title) return html[:5000] + ("..." if len(html) > 5000 else ""), md, raw except Exception as e: return "", f"Error fetching URL: {e}", "" # --------------------------------------------------------------------------- # UI # --------------------------------------------------------------------------- EXAMPLE_HTML = """ Product Page

Our Amazing Product

Welcome to the best product you've ever seen.

Features

Pricing

PlanPrice
Starter$9/mo
Pro$29/mo

FAQ

Is there a free trial?

Yes! 14 days free, no credit card required.

""" with gr.Blocks(title="distill-structure", theme=gr.themes.Soft()) as demo: gr.Markdown("# distill-structure\nHTML section analyzer — fine-tuned Qwen3.5-2B") with gr.Tabs(): with gr.Tab("Paste HTML"): with gr.Row(): with gr.Column(): html_input = gr.Textbox( label="HTML", placeholder="Paste HTML here...", lines=15, value=EXAMPLE_HTML, ) title_input = gr.Textbox(label="Page title (optional)", placeholder="Auto-detected from ") btn_html = gr.Button("Analyze", variant="primary") with gr.Column(): sections_out = gr.Markdown(label="Sections") raw_out = gr.Textbox(label="Raw JSON output", lines=10) btn_html.click(analyze_html, inputs=[html_input, title_input], outputs=[sections_out, raw_out]) with gr.Tab("From URL"): with gr.Row(): with gr.Column(): url_input = gr.Textbox(label="URL", placeholder="https://news.ycombinator.com") btn_url = gr.Button("Fetch & Analyze", variant="primary") html_preview = gr.Textbox(label="Fetched HTML (preview)", lines=8) with gr.Column(): sections_out2 = gr.Markdown(label="Sections") raw_out2 = gr.Textbox(label="Raw JSON output", lines=10) btn_url.click(analyze_url, inputs=[url_input], outputs=[html_preview, sections_out2, raw_out2]) gr.Markdown(""" --- **Model**: [nahidstaq/distill-structure](https://huggingface.co/nahidstaq/distill-structure) · **Base**: Qwen3.5-2B · **Task**: HTML structure analysis """) if __name__ == "__main__": demo.launch()