| """Gradio demo for distill-structure model."""
|
|
|
| import json
|
| import re
|
|
|
| import gradio as gr
|
| import torch
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repo of the fine-tuned structure-analysis model.
MODEL_ID = "nahidstaq/distill-structure"

# System prompt sent with every request; it defines the model's task and
# the JSON schema (title / start_text / content_type / assets) expected back.
SYSTEM = (
    "You are an HTML structure analyzer. Given a compact DOM representation "
    "of a web page (with headings removed), identify the logical sections. "
    "Output a JSON array of sections, each with title, start_text, content_type, and assets fields."
)

# Lazily-initialized singletons; populated by _load() on first use so the
# demo starts fast and only pays the model-load cost when first queried.
_model = None
_tokenizer = None
|
|
|
|
|
def _load():
    """Return the cached (model, tokenizer) pair, loading them on first call.

    Subsequent calls are cheap: the pair is memoized in the module-level
    ``_model`` / ``_tokenizer`` globals.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    # First call: pick device/dtype, then load and cache everything.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    _model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=dtype, device_map="auto")
    _model.eval()
    return _model, _tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
def _compact_dom(html: str) -> str:
    """Convert raw HTML into a compact, indented DOM outline for the model.

    Removes headings (the model's task is to recover sections without them)
    plus <script>, <style>, and <head>, then walks the body emitting one
    line per element with truncated text and id/class/role attributes.
    Returns a raw-text prefix when the HTML cannot be parsed; output is
    capped near 4 KB.
    """
    from lxml import html as lxml_html

    try:
        doc = lxml_html.fromstring(html)
    except Exception:
        # Unparseable input: hand the model a raw prefix instead of failing.
        return html[:3000]

    # Strip headings and elements that carry no layout signal.
    for tag in ("h1", "h2", "h3", "h4", "h5", "h6", "script", "style", "head"):
        for el in doc.findall(f".//{tag}"):
            p = el.getparent()
            if p is not None:
                p.remove(el)

    def _walk(el, depth=0):
        # Comments / processing instructions have a non-string .tag — skip.
        if not hasattr(el, "tag") or not isinstance(el.tag, str):
            return ""
        tag = el.tag
        indent = " " * depth

        if tag == "img":
            # Images are represented by alt text only.
            alt = el.get("alt", "")
            return f'{indent}<img alt="{alt}">' if alt else f'{indent}<img>'

        if tag == "a":
            # Links keep a truncated href and anchor text.
            text = (el.text_content() or "").strip()[:40]
            href = (el.get("href") or "")[:60]
            return f'{indent}<a href="{href}"> {text}'

        if tag in ("td", "th"):
            # Table cells: recurse into element children when present,
            # otherwise emit the cell's (truncated) flattened text.
            children = [c for c in el if hasattr(c, "tag") and isinstance(c.tag, str)]
            if children and depth < 8:
                lines = [f"{indent}<{tag}>"]
                for child in children:
                    r = _walk(child, depth + 1)
                    if r:
                        lines.append(r)
                return "\n".join(lines)
            text = (el.text_content() or "").strip()[:60]
            return f"{indent}<{tag}> {text}" if text else ""

        if depth > 7:
            # Deeply nested subtree: collapse it into a short text summary.
            text = (el.text_content() or "").strip()[:80]
            return f"{indent}[... {text}...]" if text else ""

        # Generic element: tag plus a few identifying attributes and text.
        text = (el.text or "").strip()[:50]
        attrs = ""
        for a in ("id", "class", "role"):
            v = el.get(a)
            if v:
                attrs += f' {a}="{v[:30]}"'

        line = f"{indent}<{tag}{attrs}>"
        if text:
            line += f" {text}"
        lines = [line]
        for child in el:
            r = _walk(child, depth + 1)
            if r:
                lines.append(r)
        return "\n".join(lines)

    # BUG FIX: the original used `doc.find(".//body") or doc`, which
    # truth-tests an lxml element. lxml elements are falsy when they have
    # no children, so a childless <body> was silently replaced by the whole
    # document root (and newer lxml warns on element truth-testing).
    # Compare against None explicitly instead.
    body = doc.find(".//body")
    if body is None:
        body = doc
    result = _walk(body)

    if len(result) > 4096:
        result = result[:4096] + "\n... (truncated)"
    return result
|
|
|
|
|
| def _extract_title(html: str) -> str:
|
| m = re.search(r"<title>(.*?)</title>", html, re.I | re.S)
|
| return m.group(1).strip() if m else "Untitled"
|
|
|
|
|
| def _parse(raw: str) -> list[dict]:
|
| try:
|
| data = json.loads(raw)
|
| if isinstance(data, list):
|
| return data
|
| except json.JSONDecodeError:
|
| pass
|
| m = re.search(r"\[.*?\]", raw, re.S)
|
| if m:
|
| try:
|
| return json.loads(m.group())
|
| except json.JSONDecodeError:
|
| pass
|
| return []
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_html(html: str, page_title: str) -> tuple[str, str]:
    """Analyze pasted HTML with the model.

    Returns a (markdown section table, raw model output) pair. The title is
    taken from *page_title* when provided, otherwise from the HTML <title>.
    """
    if not html.strip():
        return "Please paste some HTML.", ""

    model, tokenizer = _load()
    compact = _compact_dom(html)
    title = page_title.strip() or _extract_title(html)

    # Build the chat prompt: fixed system instruction + page context.
    chat = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": f"Page: {title}\n\n{compact}"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Deterministic decoding; temperature/top_p cleared to silence warnings
    # about sampling params when do_sample=False.
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            temperature=None,
            top_p=None,
            pad_token_id=tokenizer.eos_token_id,
        )

    prompt_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()
    sections = _parse(raw)

    if not sections:
        return "_Could not parse sections from model output._", raw

    # Render the parsed sections as a markdown table.
    rows = ["| # | Type | Title | Start text |", "|---|---|---|---|"]
    for idx, sec in enumerate(sections, 1):
        ctype = sec.get("content_type", "?")
        sec_title = sec.get("title", "")
        snippet = (sec.get("start_text") or "")[:50]
        rows.append(f"| {idx} | `{ctype}` | {sec_title} | {snippet} |")
    md = "\n".join(rows) + "\n"
    return md, raw
|
|
|
|
|
def analyze_url(url: str) -> tuple[str, str, str]:
    """Fetch *url* and run the HTML analysis on the response body.

    Returns (HTML preview, markdown section table, raw model output).
    Any fetch or analysis failure is reported in the second slot rather
    than raised, so the UI always gets something to display.
    """
    if not url.strip():
        return "", "Please enter a URL.", ""
    try:
        import httpx

        resp = httpx.get(
            url,
            follow_redirects=True,
            timeout=10,
            headers={"User-Agent": "Mozilla/5.0"},
        )
        page = resp.text
        page_title = _extract_title(page)
        md, raw = analyze_html(page, page_title)
        preview = page[:5000] + ("..." if len(page) > 5000 else "")
        return preview, md, raw
    except Exception as e:  # deliberate best-effort: surface the error in the UI
        return "", f"Error fetching URL: {e}", ""
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sample page pre-filled in the "Paste HTML" tab so the demo works out of
# the box; its <h1>/<h2>/<h3> headings are stripped by _compact_dom before
# the model sees it, per the task definition.
EXAMPLE_HTML = """<!DOCTYPE html>
<html>
<head><title>Product Page</title></head>
<body>
<h1>Our Amazing Product</h1>
<p>Welcome to the best product you've ever seen.</p>
<h2>Features</h2>
<ul>
<li>Lightning fast</li>
<li>Easy to use</li>
<li>Affordable pricing</li>
</ul>
<h2>Pricing</h2>
<table>
<tr><th>Plan</th><th>Price</th></tr>
<tr><td>Starter</td><td>$9/mo</td></tr>
<tr><td>Pro</td><td>$29/mo</td></tr>
</table>
<h2>FAQ</h2>
<h3>Is there a free trial?</h3>
<p>Yes! 14 days free, no credit card required.</p>
</body>
</html>"""
|
|
|
# --- Gradio UI wiring -------------------------------------------------------
with gr.Blocks(title="distill-structure", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# distill-structure\nHTML section analyzer — fine-tuned Qwen3.5-2B")

    with gr.Tabs():
        # Tab 1: user pastes raw HTML (pre-filled with EXAMPLE_HTML) and
        # optionally supplies a page title.
        with gr.Tab("Paste HTML"):
            with gr.Row():
                with gr.Column():
                    html_input = gr.Textbox(
                        label="HTML",
                        placeholder="Paste HTML here...",
                        lines=15,
                        value=EXAMPLE_HTML,
                    )
                    title_input = gr.Textbox(label="Page title (optional)", placeholder="Auto-detected from <title>")
                    btn_html = gr.Button("Analyze", variant="primary")
                with gr.Column():
                    sections_out = gr.Markdown(label="Sections")
                    raw_out = gr.Textbox(label="Raw JSON output", lines=10)

            btn_html.click(analyze_html, inputs=[html_input, title_input], outputs=[sections_out, raw_out])

        # Tab 2: fetch a live page by URL, show a preview of the fetched
        # HTML, and run the same analysis on it.
        with gr.Tab("From URL"):
            with gr.Row():
                with gr.Column():
                    url_input = gr.Textbox(label="URL", placeholder="https://news.ycombinator.com")
                    btn_url = gr.Button("Fetch & Analyze", variant="primary")
                    html_preview = gr.Textbox(label="Fetched HTML (preview)", lines=8)
                with gr.Column():
                    sections_out2 = gr.Markdown(label="Sections")
                    raw_out2 = gr.Textbox(label="Raw JSON output", lines=10)

            btn_url.click(analyze_url, inputs=[url_input], outputs=[html_preview, sections_out2, raw_out2])

    # Footer with model provenance links.
    gr.Markdown("""
---
**Model**: [nahidstaq/distill-structure](https://huggingface.co/nahidstaq/distill-structure) ·
**Base**: Qwen3.5-2B ·
**Task**: HTML structure analysis
""")

if __name__ == "__main__":
    demo.launch()
|
|
|