# Source: Hugging Face Space by nahidstaq — commit 88ab604 (verified)
# "Add fine-tuned Qwen3.5-2B distill-structure model with Gradio demo"
"""Gradio demo for distill-structure model."""
import json
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

# Hugging Face Hub id of the fine-tuned checkpoint this demo serves.
MODEL_ID = "nahidstaq/distill-structure"

# System prompt sent with every request; must match the format the model
# was fine-tuned on (JSON array of sections with these exact field names).
SYSTEM = (
    "You are an HTML structure analyzer. Given a compact DOM representation "
    "of a web page (with headings removed), identify the logical sections. "
    "Output a JSON array of sections, each with title, start_text, content_type, and assets fields."
)

# Module-level caches, lazily populated by _load() on first inference.
_model = None
_tokenizer = None
def _load():
    """Return the (model, tokenizer) pair, loading and caching on first call.

    Subsequent calls reuse the module-level cache, so the expensive
    from_pretrained download/initialization happens at most once.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    # bfloat16 on GPU for speed/memory; full float32 on CPU for correctness.
    on_gpu = torch.cuda.is_available()
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16 if on_gpu else torch.float32,
        device_map="auto",
    )
    _model.eval()
    return _model, _tokenizer
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _compact_dom(html: str) -> str:
    """Convert raw HTML into a compact, indented DOM sketch for the model.

    Headings are stripped (the model's task is to rediscover section
    boundaries without them), along with script/style/head. Text is
    truncated per node and the final sketch is capped at 4096 chars.
    On a parse failure the first 3000 chars of the raw HTML are returned.
    """
    from lxml import html as lxml_html  # local import: lxml only needed here

    try:
        doc = lxml_html.fromstring(html)
    except Exception:
        # Unparseable input: fall back to a raw-text prefix.
        return html[:3000]

    # Drop headings and non-content tags before walking the tree.
    for tag in ("h1", "h2", "h3", "h4", "h5", "h6", "script", "style", "head"):
        for el in doc.findall(f".//{tag}"):
            p = el.getparent()
            if p is not None:
                p.remove(el)

    def _walk(el, depth=0):
        # Skip comments / processing instructions (their .tag is not a str).
        if not hasattr(el, "tag") or not isinstance(el.tag, str):
            return ""
        tag = el.tag
        indent = " " * depth
        if tag == "img":
            alt = el.get("alt", "")
            return f'{indent}<img alt="{alt}">' if alt else f'{indent}<img>'
        if tag == "a":
            text = (el.text_content() or "").strip()[:40]
            href = (el.get("href") or "")[:60]
            return f'{indent}<a href="{href}"> {text}'
        if tag in ("td", "th"):
            # Recurse into td if it has block children, otherwise truncate
            children = [c for c in el if hasattr(c, "tag") and isinstance(c.tag, str)]
            if children and depth < 8:
                lines = [f"{indent}<{tag}>"]
                for child in children:
                    r = _walk(child, depth + 1)
                    if r:
                        lines.append(r)
                return "\n".join(lines)
            text = (el.text_content() or "").strip()[:60]
            return f"{indent}<{tag}> {text}" if text else ""
        if depth > 7:
            # Depth cap: summarize the remaining subtree as a single line.
            text = (el.text_content() or "").strip()[:80]
            return f"{indent}[... {text}...]" if text else ""
        text = (el.text or "").strip()[:50]
        attrs = ""
        for a in ("id", "class", "role"):
            v = el.get(a)
            if v:
                attrs += f' {a}="{v[:30]}"'
        line = f"{indent}<{tag}{attrs}>"
        if text:
            line += f" {text}"
        lines = [line]
        for child in el:
            r = _walk(child, depth + 1)
            if r:
                lines.append(r)
        return "\n".join(lines)

    # BUG FIX: lxml elements are falsy when they have no children, so the
    # original ``doc.find(".//body") or doc`` silently discarded an empty
    # or child-less <body>. Compare against None explicitly instead.
    body = doc.find(".//body")
    if body is None:
        body = doc
    result = _walk(body)
    # Cap the sketch so the prompt stays within the model's budget.
    if len(result) > 4096:
        result = result[:4096] + "\n... (truncated)"
    return result
def _extract_title(html: str) -> str:
m = re.search(r"<title>(.*?)</title>", html, re.I | re.S)
return m.group(1).strip() if m else "Untitled"
def _parse(raw: str) -> list[dict]:
try:
data = json.loads(raw)
if isinstance(data, list):
return data
except json.JSONDecodeError:
pass
m = re.search(r"\[.*?\]", raw, re.S)
if m:
try:
return json.loads(m.group())
except json.JSONDecodeError:
pass
return []
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def _format_sections(sections: list[dict]) -> str:
    """Render parsed sections as a markdown table, one row per section."""
    def _cell(value) -> str:
        # BUG FIX: escape pipes so cell text cannot break the table layout
        # (str() also guards against non-string values from the model).
        return str(value).replace("|", "\\|")

    md = "| # | Type | Title | Start text |\n|---|---|---|---|\n"
    for i, s in enumerate(sections, 1):
        title_s = _cell(s.get("title", ""))
        ctype = _cell(s.get("content_type", "?"))
        start = _cell((s.get("start_text") or "")[:50])
        md += f"| {i} | `{ctype}` | {title_s} | {start} |\n"
    return md


def analyze_html(html: str, page_title: str) -> tuple[str, str]:
    """Run the model over *html*.

    Returns (sections_markdown, raw_model_output). *page_title* overrides
    the <title> extracted from the HTML when non-empty.
    """
    if not html.strip():
        return "Please paste some HTML.", ""
    model, tokenizer = _load()
    compact = _compact_dom(html)
    title = page_title.strip() or _extract_title(html)
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": f"Page: {title}\n\n{compact}"},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            temperature=None,  # explicit None silences sampling-config warnings
            top_p=None,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (skip the echoed prompt).
    raw = tokenizer.decode(ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
    sections = _parse(raw)
    if sections:
        md = _format_sections(sections)
    else:
        md = "_Could not parse sections from model output._"
    return md, raw
def analyze_url(url: str) -> tuple[str, str, str]:
    """Fetch a page over HTTP and analyze its structure.

    Returns (html_preview, sections_markdown, raw_model_output); on any
    fetch or analysis error the middle slot carries the error message.
    """
    if not url.strip():
        return "", "Please enter a URL.", ""
    try:
        import httpx

        response = httpx.get(
            url,
            follow_redirects=True,
            timeout=10,
            headers={"User-Agent": "Mozilla/5.0"},
        )
        page = response.text
        md, raw = analyze_html(page, _extract_title(page))
        preview = page[:5000]
        if len(page) > 5000:
            preview += "..."
        return preview, md, raw
    except Exception as e:
        return "", f"Error fetching URL: {e}", ""
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Pre-filled sample input for the "Paste HTML" tab: a small marketing page
# with headings, a list, a table, and an FAQ — enough structure for the
# model to find multiple sections.
EXAMPLE_HTML = """<!DOCTYPE html>
<html>
<head><title>Product Page</title></head>
<body>
<h1>Our Amazing Product</h1>
<p>Welcome to the best product you've ever seen.</p>
<h2>Features</h2>
<ul>
<li>Lightning fast</li>
<li>Easy to use</li>
<li>Affordable pricing</li>
</ul>
<h2>Pricing</h2>
<table>
<tr><th>Plan</th><th>Price</th></tr>
<tr><td>Starter</td><td>$9/mo</td></tr>
<tr><td>Pro</td><td>$29/mo</td></tr>
</table>
<h2>FAQ</h2>
<h3>Is there a free trial?</h3>
<p>Yes! 14 days free, no credit card required.</p>
</body>
</html>"""
# Two-tab layout: paste raw HTML directly, or fetch it from a URL.
with gr.Blocks(title="distill-structure", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# distill-structure\nHTML section analyzer — fine-tuned Qwen3.5-2B")
    with gr.Tabs():
        # Tab 1: paste HTML and analyze it directly.
        with gr.Tab("Paste HTML"):
            with gr.Row():
                with gr.Column():
                    html_input = gr.Textbox(
                        label="HTML",
                        placeholder="Paste HTML here...",
                        lines=15,
                        value=EXAMPLE_HTML,
                    )
                    title_input = gr.Textbox(label="Page title (optional)", placeholder="Auto-detected from <title>")
                    btn_html = gr.Button("Analyze", variant="primary")
                with gr.Column():
                    sections_out = gr.Markdown(label="Sections")
                    raw_out = gr.Textbox(label="Raw JSON output", lines=10)
            btn_html.click(analyze_html, inputs=[html_input, title_input], outputs=[sections_out, raw_out])
        # Tab 2: fetch a live page over HTTP, then analyze it.
        with gr.Tab("From URL"):
            with gr.Row():
                with gr.Column():
                    url_input = gr.Textbox(label="URL", placeholder="https://news.ycombinator.com")
                    btn_url = gr.Button("Fetch & Analyze", variant="primary")
                    html_preview = gr.Textbox(label="Fetched HTML (preview)", lines=8)
                with gr.Column():
                    sections_out2 = gr.Markdown(label="Sections")
                    raw_out2 = gr.Textbox(label="Raw JSON output", lines=10)
            btn_url.click(analyze_url, inputs=[url_input], outputs=[html_preview, sections_out2, raw_out2])
    # Footer with model/card links.
    gr.Markdown("""
---
**Model**: [nahidstaq/distill-structure](https://huggingface.co/nahidstaq/distill-structure) ·
**Base**: Qwen3.5-2B ·
**Task**: HTML structure analysis
""")
# Launch the Gradio app when run as a script (e.g. on a Hugging Face Space).
if __name__ == "__main__":
    demo.launch()