Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Gradio: **SQL compare** — fine-tuned Qwen SQL demo model vs Hub base (Transformers). | |
| No smolagents tab (compare only). | |
| Env (see ``sql_compare_ui_qwen/.env.example`` and README): ``QWEN_COMPARE_*`` for UI; | |
| repo ``.env`` for ``HF_TOKEN`` and shared project settings. | |
| """ | |
| from __future__ import annotations | |
| import os | |
# Process-wide environment defaults. They are set here, ahead of the gradio /
# torch / transformers imports further down, so those libraries see them.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "HF_DEACTIVATE_ASYNC_LOAD": "1",
    }
)
# Unless explicitly opted in, extend PYTHONWARNINGS with a resource_tracker
# ignore rule while keeping whatever the user had already configured.
if os.environ.get("QWEN_COMPARE_SHOW_RESOURCE_TRACKER_WARNINGS", "").strip().lower() != "true":
    _rt = "ignore:resource_tracker:UserWarning"
    _pw = os.environ.get("PYTHONWARNINGS", "").strip()
    os.environ["PYTHONWARNINGS"] = ",".join(part for part in (_pw, _rt) if part)
| import gc | |
| import csv | |
| import html | |
| import io | |
| import re | |
| import socket | |
| import sqlite3 | |
| import sys | |
| import warnings | |
| from pathlib import Path | |
# Directory containing this file, and its parent (treated as the repo root).
ROOT = Path(__file__).resolve().parent
REPO_ROOT = ROOT.parent
# Make the repo root importable so the ``sql_compare_ui_qwen.*`` import below
# can resolve when this file is run directly as a script.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
# python-dotenv is optional: when it is installed, load ``.env`` from the app
# directory first and then from the repo root; when it is not, skip silently.
try:
    from dotenv import load_dotenv
    for env_path in (ROOT / ".env", REPO_ROOT / ".env"):
        if env_path.is_file():
            load_dotenv(env_path)
except ImportError:
    pass
def _install_resource_tracker_warning_silencer() -> None:
    """Hide the "resource_tracker ... leaked semaphore" shutdown warning.

    Does nothing when ``QWEN_COMPARE_SHOW_RESOURCE_TRACKER_WARNINGS`` is
    ``true``. Otherwise it registers an ignore filter for that message and
    wraps ``warnings.showwarning`` so a matching warning that still reaches
    the hook is dropped; every other warning is forwarded to the previous hook.
    """
    show_flag = os.environ.get("QWEN_COMPARE_SHOW_RESOURCE_TRACKER_WARNINGS", "")
    if show_flag.strip().lower() == "true":
        return

    warnings.filterwarnings(
        "ignore",
        message=r".*resource_tracker:.*[Ll]eaked.*semaphore.*",
        category=UserWarning,
    )

    previous_hook = warnings.showwarning
    # All of these fragments must be present for a warning to be swallowed.
    needles = ("resource_tracker", "leaked", "semaphore", "clean up at shutdown")

    def _showwarning(message, category, filename, lineno, file=None, line=None):
        try:
            rendered = str(message)
        except Exception:
            # A message that cannot be stringified is never treated as a match.
            rendered = ""
        if all(needle in rendered for needle in needles):
            return
        previous_hook(message, category, filename, lineno, file=file, line=line)

    warnings.showwarning = _showwarning  # type: ignore[assignment]


_install_resource_tracker_warning_silencer()
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoTokenizer | |
| try: | |
| from sql_compare_ui_qwen.prompting import build_prompt | |
| except ModuleNotFoundError: | |
| from prompting import build_prompt | |
# Lazily populated cache for the Hub base model; filled by ``predict_hf`` and
# cleared by ``unload_hf_model``.
_hf_model = None
_hf_tokenizer = None
_hf_model_id: str | None = None
# Same cache for the fine-tuned model; filled by ``predict_finetuned_hf`` and
# cleared by ``unload_ft_hf_model``.
_ft_hf_model = None
_ft_hf_tokenizer = None
_ft_hf_model_id: str | None = None
# Hub repo ids of the two models being compared.
FINETUNED_HUB_MODEL_ID = "Tuana/qwen35-08b-text2sql"
BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B"
# Example questions offered in the UI; the first entry is the default question.
DEMO_QUESTION_EXAMPLES: tuple[str, ...] = (
    "Count how many management rows exist per temporary_acting value",
    "List all department names.",
    "Count how many management rows exist per department.",
    "Which departments were created before the year 2000?",
    "For each department, show the department name and the name of its head.",
    "List the names of heads who were born in Alabama.",
    "Which heads are temporary acting in their management role?",
    "How many departments are there?",
)
def _env(name: str, default: str = "") -> str:
    """Return environment variable *name* with surrounding whitespace removed.

    *default* is returned when the variable is unset or blank.
    """
    raw = os.environ.get(name)
    cleaned = "" if raw is None else str(raw).strip()
    return cleaned if cleaned else default
def _hf_token() -> str | None:
    """Return the Hugging Face token, or ``None`` when none is configured.

    ``QWEN_COMPARE_HF_TOKEN`` takes precedence over ``HF_TOKEN``.
    """
    for variable in ("QWEN_COMPARE_HF_TOKEN", "HF_TOKEN"):
        candidate = _env(variable).strip()
        if candidate:
            return candidate
    return None
def _demo_data_dir() -> Path:
    """Locate the directory holding the synthetic demo CSV files.

    The app directory is checked before the repo root; a candidate qualifies
    when it contains ``department.csv``. If neither qualifies, the
    app-directory path is returned even though it may not exist.
    """
    relative = Path("data") / "spider_eval_synthetic"
    candidates = (ROOT / relative, REPO_ROOT / relative)
    found = next((c for c in candidates if (c / "department.csv").is_file()), None)
    return candidates[0] if found is None else found
def _first_free_port(host: str, start: int, *, max_tries: int = 40) -> int:
    """Return the first port in ``start .. start + max_tries - 1`` that binds on *host*.

    Each candidate is probed by binding a throwaway TCP socket, which is
    closed again before returning.

    Raises:
        RuntimeError: if no port in the range could be bound.
    """
    stop = start + max_tries
    for candidate in range(start, stop):
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.bind((host, candidate))
        except OSError:
            continue
        else:
            return candidate
        finally:
            probe.close()
    raise RuntimeError(f"No free TCP port in {start}..{stop - 1} on {host!r}")
def _mps_is_available() -> bool:
    """Return whether this torch build exposes an available MPS backend."""
    backend = getattr(torch.backends, "mps", None)
    if backend is None:
        # torch build without an ``mps`` backend attribute.
        return False
    return backend.is_available()
def _mps_load_dtype() -> torch.dtype:
    """Map ``QWEN_COMPARE_MPS_DTYPE`` to a torch dtype (``float32`` when unset or unknown)."""
    aliases = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "16": torch.float16,
    }
    return aliases.get(_env("QWEN_COMPARE_MPS_DTYPE").lower(), torch.float32)
def _model_load_spec() -> tuple[torch.dtype, str | None, str | None, str]:
    """Decide how to place a model from env settings and available hardware.

    Reads ``QWEN_COMPARE_HUB_DEVICE_MAP`` (falling back to
    ``QWEN_COMPARE_DEVICE_MAP``) and returns
    ``(load_dtype, device_map, to_device, reason)``:

    * ``device_map`` is only set when CUDA is used; otherwise it is ``None``.
    * ``to_device`` is ``"mps"`` or ``"cpu"`` when the model must be moved
      after loading, and ``None`` when ``device_map`` handles placement.
    * ``reason`` is a short tag for logging that records which rule fired.
    """
    requested = (
        _env("QWEN_COMPARE_HUB_DEVICE_MAP") or _env("QWEN_COMPARE_DEVICE_MAP")
    ).lower()

    # Explicit CPU / MPS requests never consult CUDA.
    if requested in ("none", "null", "cpu"):
        return torch.float32, None, "cpu", requested or "cpu"
    if requested == "mps":
        if _mps_is_available():
            return _mps_load_dtype(), None, "mps", "mps"
        return torch.float32, None, "cpu", "mps_unavailable"

    # Every remaining case follows the same CUDA -> MPS -> CPU preference
    # and differs only in the device_map passed to CUDA and the reason tags.
    if not requested:
        cuda_map, cuda_reason = "auto", "cuda_auto"
        mps_reason, cpu_reason = "mps_default", "cpu_default"
    elif requested == "auto" or requested.startswith("cuda"):
        cuda_map, cuda_reason = requested, requested
        mps_reason, cpu_reason = f"{requested}_cuda_missing", f"{requested}_no_accel"
    else:
        cuda_map, cuda_reason = requested, requested
        mps_reason, cpu_reason = f"{requested}_mps_fallback", f"{requested}_cpu_fallback"

    if torch.cuda.is_available():
        return torch.bfloat16, cuda_map, None, cuda_reason
    if _mps_is_available():
        return _mps_load_dtype(), None, "mps", mps_reason
    return torch.float32, None, "cpu", cpu_reason
def _log_model_device(kind: str, model: torch.nn.Module, reason: str, dtype: torch.dtype, device_map: str | None, to_device: str | None) -> None:
    """Print one ``QWEN_DEVICE`` line describing where *model* ended up.

    The device and dtype are taken from the model's first parameter and shown
    next to the requested load dtype, ``device_map`` and post-load device.
    """
    first_param = next(model.parameters())
    details = " | ".join(
        (
            f"reason={reason}",
            f"param_device={first_param.device}",
            f"param_dtype={first_param.dtype}",
            f"load_dtype={dtype}",
            f"device_map={device_map!r}",
            f"post_to={to_device!r}",
        )
    )
    print(f"QWEN_DEVICE {kind}: {details}", flush=True)
def unload_hf_model() -> None:
    """Drop the cached Hub base model and tokenizer and release accelerator caches."""
    global _hf_model, _hf_tokenizer, _hf_model_id
    _hf_model = _hf_tokenizer = _hf_model_id = None
    # Collect first so the dropped tensors are actually freed before the
    # backend caches are emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if _mps_is_available():
        try:
            torch.mps.empty_cache()
        except Exception:
            # Best effort only: a failure to empty the MPS cache is ignored.
            pass
def unload_ft_hf_model() -> None:
    """Drop the cached fine-tuned model and tokenizer and release accelerator caches."""
    global _ft_hf_model, _ft_hf_tokenizer, _ft_hf_model_id
    _ft_hf_model = _ft_hf_tokenizer = _ft_hf_model_id = None
    # Collect first so the dropped tensors are actually freed before the
    # backend caches are emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if _mps_is_available():
        try:
            torch.mps.empty_cache()
        except Exception:
            # Best effort only: a failure to empty the MPS cache is ignored.
            pass
def _max_new_tokens() -> int:
    """Return the generation budget from ``QWEN_COMPARE_MAX_NEW_TOKENS`` / ``MAX_NEW_TOKENS`` (default 512).

    Raises:
        ValueError: if the configured value is not an integer.
    """
    return int(_env("QWEN_COMPARE_MAX_NEW_TOKENS", _env("MAX_NEW_TOKENS", "512")) or "512")


def _load_tokenizer_and_model(mid: str, kind: str):
    """Load tokenizer and model for Hub repo *mid* and place the model on its device.

    *kind* only labels the ``QWEN_DEVICE`` log line. Returns
    ``(tokenizer, model)`` with the model already in eval mode.
    """
    dtype, device_map, to_device, device_reason = _model_load_spec()
    token = _hf_token()

    # NOTE(review): trust_remote_code=True lets the Hub repo run its own Python
    # code at load time; only point the model ids at repos you trust.
    tok_kw: dict = {"trust_remote_code": True, "use_fast": True}
    if token:
        tok_kw["token"] = token
    try:
        tokenizer = AutoTokenizer.from_pretrained(mid, **tok_kw)
    except (AttributeError, TypeError) as e:
        err = str(e)
        if "'list' object has no attribute 'keys'" not in err and "not a string" not in err.lower():
            raise
        # These two specific failures are retried with an explicit empty
        # ``extra_special_tokens`` mapping (presumably a tokenizer-config
        # incompatibility — TODO confirm which transformers versions need it).
        tokenizer = AutoTokenizer.from_pretrained(mid, **tok_kw, extra_special_tokens={})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    kw: dict = {
        "trust_remote_code": True,
        "torch_dtype": dtype,
        "low_cpu_mem_usage": device_map is None,
    }
    if token:
        kw["token"] = token
    if device_map is not None:
        kw["device_map"] = device_map
    try:
        model = AutoModelForImageTextToText.from_pretrained(mid, **kw)
    except (OSError, ValueError, TypeError):
        # Fall back to the plain causal-LM class when the image-text class
        # cannot load this checkpoint.
        model = AutoModelForCausalLM.from_pretrained(mid, **kw)
    if to_device:
        model = model.to(to_device)
    model.eval()
    _log_model_device(kind, model, device_reason, dtype, device_map, to_device)
    return tokenizer, model


def _generate_reply(model, tokenizer, prompt: str, max_new: int) -> str:
    """Run one greedy chat-template generation and return only the new text."""
    messages = [{"role": "user", "content": prompt}]
    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
    except TypeError:
        # Retry without ``enable_thinking`` when the template call rejects it.
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    inputs = tokenizer(text, return_tensors="pt")
    dev = next(model.parameters()).device
    inputs = {k: v.to(dev) for k, v in inputs.items()}
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Strip the prompt tokens so only the completion is decoded.
    in_len = inputs["input_ids"].shape[-1]
    gen_ids = out[0, in_len:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()


def predict_hf(prompt: str) -> str:
    """Generate a reply to *prompt* with the Hub base model (``BASE_MODEL_ID``).

    The model and tokenizer are loaded on first use and cached in module
    globals until ``unload_hf_model`` clears them. When
    ``QWEN_COMPARE_SKIP_HUB`` is ``true`` a notice is returned instead.

    This function never raises: any failure, including an invalid max-token
    setting, is returned as a ``"Hub base: <repr>"`` string for display.
    """
    global _hf_model, _hf_tokenizer, _hf_model_id
    if _env("QWEN_COMPARE_SKIP_HUB").lower() == "true":
        return (
            "Hub column skipped (`QWEN_COMPARE_SKIP_HUB=true`). Set `QWEN_COMPARE_SKIP_HUB=false` "
            "to load the Hub model again."
        )
    mid = BASE_MODEL_ID
    try:
        max_new = _max_new_tokens()
        if _hf_model is None or _hf_tokenizer is None or _hf_model_id != mid:
            # Assign only after a fully successful load so a failure never
            # leaves a half-populated cache behind.
            tokenizer, model = _load_tokenizer_and_model(mid, "hub")
            _hf_model, _hf_tokenizer, _hf_model_id = model, tokenizer, mid
        return _generate_reply(_hf_model, _hf_tokenizer, prompt, max_new)
    except Exception as ex:
        return f"Hub base: {ex!r}"


def predict_finetuned_hf(prompt: str) -> str:
    """Generate a reply to *prompt* with the fine-tuned model (``FINETUNED_HUB_MODEL_ID``).

    The model and tokenizer are loaded on first use and cached in module
    globals until ``unload_ft_hf_model`` clears them. When
    ``QWEN_COMPARE_SKIP_FINETUNED`` is ``true`` a notice is returned instead.

    This function never raises: any failure, including an invalid max-token
    setting, is returned as a ``"Fine-tuned HF: <repr>"`` string for display.
    """
    global _ft_hf_model, _ft_hf_tokenizer, _ft_hf_model_id
    if _env("QWEN_COMPARE_SKIP_FINETUNED").lower() == "true":
        return (
            "Fine-tuned column skipped (`QWEN_COMPARE_SKIP_FINETUNED=true`). "
            "Set `QWEN_COMPARE_SKIP_FINETUNED=false` to load it again."
        )
    mid = FINETUNED_HUB_MODEL_ID
    try:
        max_new = _max_new_tokens()
        if _ft_hf_model is None or _ft_hf_tokenizer is None or _ft_hf_model_id != mid:
            # Assign only after a fully successful load so a failure never
            # leaves a half-populated cache behind.
            tokenizer, model = _load_tokenizer_and_model(mid, "fine-tuned-hf")
            _ft_hf_model, _ft_hf_tokenizer, _ft_hf_model_id = model, tokenizer, mid
        return _generate_reply(_ft_hf_model, _ft_hf_tokenizer, prompt, max_new)
    except Exception as ex:
        return f"Fine-tuned HF: {ex!r}"
def _compare_sqlite_db_path() -> Path:
    """Return the absolute path of the SQLite file that generated SQL runs against.

    ``QWEN_COMPARE_DB_PATH`` overrides the default ``synthetic.db`` inside the
    demo data directory; ``~`` is expanded and the path is resolved.
    """
    bundled_default = _demo_data_dir() / "synthetic.db"
    configured = _env("QWEN_COMPARE_DB_PATH", str(bundled_default))
    return Path(configured).expanduser().resolve()
def _load_csv_rows(data_dir: Path) -> tuple[list[dict[str, str]], list[dict[str, str]], list[dict[str, str]]]:
    """Read the three demo CSV files from *data_dir*.

    Returns ``(departments, heads, management)``, each a list of row dicts
    keyed by the CSV header names.
    """

    def _read(stem: str) -> list[dict[str, str]]:
        with (data_dir / f"{stem}.csv").open(newline="", encoding="utf-8") as handle:
            return list(csv.DictReader(handle))

    return _read("department"), _read("head"), _read("management")
def _ensure_compare_sqlite_db() -> Path:
    """Return the demo SQLite database path, building it from the CSVs if missing.

    An existing file is returned untouched. Otherwise the three tables are
    created and filled in a staging file next to the target, which is moved
    into place with ``os.replace`` only after everything has been committed.
    A failure part-way through (for example a CSV missing a column) therefore
    never leaves a half-built database that later calls would accept as valid.

    Raises:
        OSError: if the CSV files cannot be read or the target is not writable.
        KeyError: if a CSV row lacks an expected column.
        sqlite3.Error: if SQLite rejects the build.
    """
    db = _compare_sqlite_db_path()
    if db.is_file():
        return db
    data_dir = _demo_data_dir()
    departments, heads, management = _load_csv_rows(data_dir)
    db.parent.mkdir(parents=True, exist_ok=True)

    # Per-process staging name so two processes building at once do not
    # write into the same file.
    staging = db.with_name(f"{db.name}.{os.getpid()}.tmp")
    staging.unlink(missing_ok=True)
    try:
        conn = sqlite3.connect(staging)
        try:
            conn.executescript(
                """
                DROP TABLE IF EXISTS department;
                DROP TABLE IF EXISTS management;
                DROP TABLE IF EXISTS head;
                CREATE TABLE department (
                department_id VARCHAR,
                name VARCHAR,
                creation VARCHAR
                );
                CREATE TABLE management (
                department_id VARCHAR,
                head_id VARCHAR,
                temporary_acting VARCHAR
                );
                CREATE TABLE head (
                head_id VARCHAR,
                name VARCHAR,
                born_state VARCHAR
                );
                """
            )
            conn.executemany(
                "INSERT INTO department (department_id, name, creation) VALUES (?, ?, ?)",
                [(r["department_id"], r["name"], r["creation"]) for r in departments],
            )
            conn.executemany(
                "INSERT INTO head (head_id, name, born_state) VALUES (?, ?, ?)",
                [(r["head_id"], r["name"], r["born_state"]) for r in heads],
            )
            conn.executemany(
                "INSERT INTO management (department_id, head_id, temporary_acting) VALUES (?, ?, ?)",
                [(r["department_id"], r["head_id"], r["temporary_acting"]) for r in management],
            )
            conn.commit()
        finally:
            conn.close()
        # Atomic swap: readers see either no database or a complete one.
        os.replace(staging, db)
    finally:
        # No-op after a successful replace; removes the partial file on failure.
        staging.unlink(missing_ok=True)
    return db
def _database_preview_rows(limit: int = 5) -> list[dict[str, str]]:
    """Return up to *limit* joined department / head / management rows for the preview.

    Rows normally come from the SQLite database, ordered by department and
    head id. If the database cannot be built or opened (for example on a
    read-only filesystem), the same join is computed directly from the CSV
    files in file order, so the preview does not depend on a writable disk.

    Each row dict has the keys ``department_id``, ``department``, ``creation``,
    ``department_head``, ``born_state`` and ``temporary_acting``. A negative
    *limit* means "no limit" in both paths, matching SQLite's ``LIMIT``.
    """
    try:
        db = _ensure_compare_sqlite_db()
        # ``as_uri`` percent-encodes the path, so characters such as ``?``,
        # ``#`` or ``%`` in a directory name cannot corrupt the SQLite URI.
        conn = sqlite3.connect(f"{db.as_uri()}?mode=ro", uri=True)
    except (OSError, sqlite3.Error):
        conn = None

    if conn is not None:
        conn.row_factory = sqlite3.Row
        try:
            rows = conn.execute(
                """
                SELECT
                d.department_id,
                d.name AS department,
                d.creation,
                h.name AS department_head,
                h.born_state,
                m.temporary_acting
                FROM department AS d
                JOIN management AS m ON m.department_id = d.department_id
                JOIN head AS h ON h.head_id = m.head_id
                ORDER BY d.department_id, h.head_id
                LIMIT ?
                """,
                (limit,),
            ).fetchall()
            return [dict(r) for r in rows]
        finally:
            conn.close()

    # CSV fallback: join the management rows against both lookup tables.
    data_dir = _demo_data_dir()
    dept_rows, head_rows, management = _load_csv_rows(data_dir)
    departments = {r["department_id"]: r for r in dept_rows}
    heads = {r["head_id"]: r for r in head_rows}
    preview: list[dict[str, str]] = []
    for rel in management:
        # Checked before appending so ``limit=0`` yields no rows.
        if 0 <= limit <= len(preview):
            break
        dept = departments.get(rel["department_id"])
        head = heads.get(rel["head_id"])
        if not dept or not head:
            # Inner-join semantics: skip relations with a dangling reference.
            continue
        preview.append(
            {
                "department_id": dept["department_id"],
                "department": dept["name"],
                "creation": dept["creation"],
                "department_head": head["name"],
                "born_state": head["born_state"],
                "temporary_acting": rel["temporary_acting"],
            }
        )
    return preview
def _database_preview_html() -> str:
    """Render the database preview rows as an HTML ``<section>`` with a table.

    Header names and all cell values are HTML-escaped before interpolation.
    """
    rows = _database_preview_rows()
    headers = [
        "department_id",
        "department",
        "creation",
        "department_head",
        "born_state",
        "temporary_acting",
    ]
    header = "".join(f"<th>{html.escape(column)}</th>" for column in headers)
    table_rows = []
    for row in rows:
        cells = "".join(
            f"<td>{html.escape(str(row.get(column, '')))}</td>" for column in headers
        )
        table_rows.append(f"<tr>{cells}</tr>")
    body = "\n".join(table_rows)
    return f"""
<section class="db-preview">
<div>
<p class="eyebrow">Dummy database preview</p>
<h2>What the model is querying</h2>
<p>
The demo database has three related tables:
<code>department</code>, <code>management</code>, and <code>head</code>.
These five rows are real examples from the local synthetic database.
</p>
</div>
<table>
<thead><tr>{header}</tr></thead>
<tbody>{body}</tbody>
</table>
</section>
"""
# Single-quoted SQL string literal, with ``''`` as the embedded-quote escape.
_SQL_STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
# Statement keywords that must not appear in a read-only query. ``replace`` is
# only rejected when it is not a function call, so the scalar ``REPLACE(...)``
# function stays usable while ``REPLACE INTO`` is still blocked.
_SQL_FORBIDDEN_PATTERNS = tuple(
    (
        keyword,
        re.compile(rf"\b{keyword}\b(?!\s*\()" if keyword == "replace" else rf"\b{keyword}\b"),
    )
    for keyword in (
        "attach",
        "pragma",
        "delete",
        "insert",
        "update",
        "drop",
        "create",
        "alter",
        "replace",
        "truncate",
        "vacuum",
        "detach",
    )
)


def _compare_validate_select(sql: str) -> tuple[bool, str]:
    """Check that *sql* is a single read-only ``SELECT`` / ``WITH`` statement.

    Returns ``(True, statement)`` with the trimmed statement (no trailing
    semicolon) on success, or ``(False, reason)`` with a human-readable
    reason on rejection.

    Keywords are matched outside single-quoted string literals only, so data
    such as ``'Update Team'`` does not trigger a rejection. Statement
    splitting is still a plain split on ``;``, so a semicolon inside a
    literal is reported as multiple statements.
    """
    s = sql.strip()
    if not s:
        return False, "empty SQL"
    parts = [p.strip() for p in s.split(";") if p.strip()]
    if len(parts) != 1:
        return False, "exactly one SQL statement (no multiple statements)"
    one = parts[0]
    low = one.lower()
    if not low.startswith(("select", "with")):
        return False, "only SELECT (or WITH … SELECT) queries are allowed"
    # Blank out literal contents before scanning; they are data, not SQL.
    scannable = _SQL_STRING_LITERAL_RE.sub("''", low)
    for keyword, pattern in _SQL_FORBIDDEN_PATTERNS:
        if pattern.search(scannable):
            return False, f"forbidden keyword: {keyword}"
    return True, one
def _compare_format_rows(cols: list[str], rows: list[tuple[object, ...]], *, limit: int) -> str:
    """Format a query result as a pipe-separated text table.

    The output is a header line, a dashed rule (8 dashes per column, capped
    at 120) and at most *limit* data rows, with ``None`` shown as ``NULL``.
    When rows were cut off, a trailing note reports the limit and the total.
    Returns ``"(no columns)"`` when *cols* is empty.
    """
    if not cols:
        return "(no columns)"
    lines = [" | ".join(cols), "-" * min(120, 8 * len(cols))]
    lines.extend(
        " | ".join("NULL" if value is None else str(value) for value in row)
        for row in rows[:limit]
    )
    text = "\n".join(lines) + "\n"
    if len(rows) > limit:
        text += f"\n… truncated to {limit} rows ({len(rows)} returned)\n"
    return text
def _last_select_statement(s: str) -> str:
    """Return the last statement in *s* that begins a line with ``WITH`` or ``SELECT``.

    Only keywords at the start of a line (after optional whitespace) count as
    a statement start. The result runs from the last such keyword up to the
    first ``;`` after it, or to the end of the text, and is stripped.
    Returns ``""`` when no such keyword is found.
    """
    text = (s or "").strip()
    if not text:
        return ""

    last_start = None
    for match in re.finditer(r"(?:^|\n)\s*\b(WITH|SELECT)\b", text, re.MULTILINE | re.IGNORECASE):
        last_start = match.start()
    if last_start is None:
        return ""

    tail = text[last_start:].strip()
    head, semicolon, _rest = tail.partition(";")
    if semicolon:
        head = head.strip()
        if re.match(r"(?is)^\s*(?:with|select)\b", head):
            return head.rstrip(";").strip()
    return tail.rstrip(";").strip()
def _extract_sql(text: str) -> str:
    """Extract the SQL statement to execute from raw model output.

    Fenced code blocks are searched last-to-first and the first one that
    yields a statement wins; otherwise the whole text is searched. Returns
    ``""`` for empty input and for the app's own placeholder messages
    (text starting with "no local checkpoint" or containing "skipped").
    """
    body = str(text).strip() if text else ""
    if not body:
        return ""
    lowered = body.lower()
    if lowered.startswith("no local checkpoint") or "skipped" in lowered:
        return ""
    fenced = re.findall(r"```(?:sql)?\s*([\s\S]*?)```", body, re.IGNORECASE)
    # ``map`` is lazy, so scanning stops at the first fence with a statement.
    from_fence = next(
        (stmt for stmt in map(_last_select_statement, reversed(fenced)) if stmt),
        None,
    )
    return _last_select_statement(body) if from_fence is None else from_fence
def _execute_compare_sql(sql: str, *, row_limit: int = 150) -> str:
    """Validate *sql*, run it against the read-only demo database, and format the result.

    Always returns display text and never raises for expected failures:
    a placeholder when *sql* is blank, ``"Error: ..."`` when validation
    rejects it, ``"Error opening database: ..."`` when the database cannot be
    built or opened, and ``"Error executing SQL: ..."`` when SQLite rejects
    the query. On success the result table shows at most *row_limit* rows.
    """
    if not (sql or "").strip():
        return "(no SELECT / WITH extracted — nothing to run)"
    ok, stmt = _compare_validate_select(sql)
    if not ok:
        return f"Error: {stmt}"
    try:
        db = _ensure_compare_sqlite_db()
        # ``as_uri`` percent-encodes the path, so characters such as ``?``,
        # ``#`` or ``%`` in a directory name cannot corrupt the SQLite URI.
        # ``mode=ro`` makes the connection itself read-only.
        conn = sqlite3.connect(f"{db.as_uri()}?mode=ro", uri=True)
    except (OSError, sqlite3.Error) as e:
        return f"Error opening database: {e!r}"
    try:
        cur = conn.cursor()
        cur.execute(stmt)
        rows = [tuple(r) for r in cur.fetchall()]
        cols = [d[0] for d in cur.description] if cur.description else []
        return _compare_format_rows(cols, rows, limit=row_limit)
    except sqlite3.Error as e:
        return f"Error executing SQL: {e!r}"
    finally:
        conn.close()
def run_compare(user_request: str):
    """Run both models on *user_request* and execute the SQL each one produced.

    The fine-tuned model runs first, then the Hub base model. With
    ``QWEN_COMPARE_SEQUENTIAL_UNLOAD`` enabled (the default) each model is
    unloaded right after its turn unless that column is being skipped.

    Returns ``(fine_tuned_output, fine_tuned_result, hub_output, hub_result)``.
    """

    def _enabled(name: str, default: str = "") -> bool:
        return _env(name, default).lower() == "true"

    prompt = build_prompt(user_request)

    out_local = predict_finetuned_hf(prompt)
    if _enabled("QWEN_COMPARE_SEQUENTIAL_UNLOAD", "true") and not _enabled("QWEN_COMPARE_SKIP_FINETUNED"):
        unload_ft_hf_model()

    out_hf = predict_hf(prompt)
    if _enabled("QWEN_COMPARE_SEQUENTIAL_UNLOAD", "true") and not _enabled("QWEN_COMPARE_SKIP_HUB"):
        unload_hf_model()

    extracted_local, extracted_hf = _extract_sql(out_local), _extract_sql(out_hf)
    res_local = _execute_compare_sql(extracted_local)
    res_hf = _execute_compare_sql(extracted_hf)
    return out_local, res_local, out_hf, res_hf
def main() -> None:
    """Build the Gradio comparison UI and launch the server.

    The page shows an intro banner, a preview of the demo database, an
    example-question picker that fills the question box, and two result
    columns (fine-tuned model and Hub base model) filled by ``run_compare``.

    Host and port come from ``QWEN_COMPARE_GRADIO_HOST`` /
    ``QWEN_COMPARE_GRADIO_PORT``. When ``SPACE_ID`` is set the defaults are
    ``0.0.0.0`` and ``PORT`` (or 7860) and the port is used as given; otherwise
    the defaults are ``127.0.0.1`` and 7861 and the first free port from there
    is used.
    """
    hub = BASE_MODEL_ID
    fine_tuned_hub = FINETUNED_HUB_MODEL_ID
    title = "Small Text-to-SQL LLM Demo"
    # Intro banner; only module constants are interpolated into the HTML.
    hero = f"""
<div class="hero">
<h1>{title}</h1>
<p>
Ask a natural-language question and compare how a small fine-tuned model performs
against the untouched Hugging Face base model, <strong>{hub}</strong>.
</p>
<p>
The fine-tuned model starts from <strong>{hub}</strong> and is trained for
<strong>Text-to-SQL on your database</strong> with Vertex AI on Google Cloud,
using Hugging Face PyTorch Deep Learning Containers.
</p>
<p>
The app extracts each model's generated SQL, runs it against a read-only
<strong>dummy SQLite database</strong>, and shows the query results side by side.
</p>
<p class="hero-meta">
Fine-tuned model: <b>{fine_tuned_hub}</b>
Training container family: <b>Hugging Face PyTorch Training DLC</b>
</p>
</div>
"""
    # Dark theme; the checkbox_label_* values are set for both the regular and
    # the ``_dark`` variants so they look the same in either mode.
    theme = gr.themes.Monochrome(
        primary_hue="violet",
        secondary_hue="cyan",
        neutral_hue="slate",
    ).set(
        body_background_fill="#07111f",
        body_text_color="#e5edf8",
        block_background_fill="#0f1b2d",
        block_border_color="#23324a",
        block_label_background_fill="#17243a",
        block_label_text_color="#c7d2fe",
        button_primary_background_fill="#7c3aed",
        button_primary_background_fill_hover="#06b6d4",
        button_primary_text_color="#ffffff",
        input_background_fill="#0b1628",
        input_border_color="#2d3f5f",
        checkbox_label_background_fill="#0b1628",
        checkbox_label_background_fill_dark="#0b1628",
        checkbox_label_background_fill_hover="#152238",
        checkbox_label_background_fill_hover_dark="#152238",
        checkbox_label_background_fill_selected="#7c3aed",
        checkbox_label_background_fill_selected_dark="#7c3aed",
        checkbox_label_border_color="#2d3f5f",
        checkbox_label_border_color_dark="#2d3f5f",
        checkbox_label_border_color_hover="#3d5278",
        checkbox_label_border_color_hover_dark="#3d5278",
        checkbox_label_border_color_selected="#c4b5fd",
        checkbox_label_border_color_selected_dark="#c4b5fd",
        checkbox_label_text_color="#e5edf8",
        checkbox_label_text_color_dark="#e5edf8",
        checkbox_label_text_color_selected="#ffffff",
        checkbox_label_text_color_selected_dark="#ffffff",
    )
    # Extra CSS for the ``.hero`` banner and the ``.db-preview`` section.
    css = """
.gradio-container {
background:
radial-gradient(circle at top left, rgba(124, 58, 237, 0.24), transparent 28rem),
radial-gradient(circle at top right, rgba(6, 182, 212, 0.18), transparent 24rem),
#07111f;
}
.hero {
padding: 1.2rem 1.4rem;
border: 1px solid #25314a;
border-radius: 18px;
background: linear-gradient(135deg, rgba(15, 27, 45, 0.95), rgba(30, 41, 59, 0.72));
}
.hero h1 {
margin-bottom: 0.4rem;
}
.hero p {
color: #dbeafe;
font-size: 1.02rem;
line-height: 1.55;
margin: 0.45rem 0;
}
.hero code {
color: #a5f3fc;
background: rgba(8, 47, 73, 0.6);
border-radius: 6px;
padding: 0.12rem 0.3rem;
}
.hero-meta {
color: #b6c7e3 !important;
font-size: 0.92rem !important;
}
.db-preview {
margin-top: 1rem;
padding: 1.1rem 1.25rem;
border: 1px solid #25314a;
border-radius: 18px;
background: rgba(11, 22, 40, 0.78);
box-shadow: 0 18px 55px rgba(0, 0, 0, 0.22);
}
.db-preview .eyebrow {
color: #67e8f9;
font-size: 0.78rem;
font-weight: 700;
letter-spacing: 0.08em;
margin: 0;
text-transform: uppercase;
}
.db-preview h2 {
color: #eef4ff;
margin: 0.15rem 0 0.35rem;
}
.db-preview p {
color: #cbd5e1;
margin: 0 0 0.85rem;
}
.db-preview code {
color: #a5f3fc;
background: rgba(8, 47, 73, 0.65);
border-radius: 6px;
padding: 0.08rem 0.28rem;
}
.db-preview table {
width: 100%;
border-collapse: collapse;
overflow: hidden;
border-radius: 12px;
font-size: 0.9rem;
}
.db-preview th,
.db-preview td {
border-bottom: 1px solid #23324a;
padding: 0.62rem 0.7rem;
text-align: left;
}
.db-preview th {
color: #bfdbfe;
background: rgba(30, 41, 59, 0.92);
font-weight: 700;
}
.db-preview td {
color: #e2e8f0;
background: rgba(15, 23, 42, 0.62);
}
"""
    _default_question = DEMO_QUESTION_EXAMPLES[0]
    with gr.Blocks(title=title, theme=theme, css=css) as demo:
        gr.Markdown(hero)
        # Rendered once while the UI is built; this also triggers creation of
        # the demo database if it does not exist yet.
        gr.HTML(_database_preview_html())
        example_radio = gr.Radio(
            label="Example question",
            choices=list(DEMO_QUESTION_EXAMPLES),
            value=_default_question,
        )
        inp = gr.Textbox(
            label="Ask the database",
            value=_default_question,
            placeholder="e.g. List all department names.",
            lines=4,
        )
        # Picking an example copies its text into the question box.
        example_radio.change(fn=lambda q: q, inputs=example_radio, outputs=inp)
        btn = gr.Button("Generate and compare SQL", variant="primary")
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                gr.Markdown("#### Fine-tuned model from Hugging Face")
                out_local = gr.Textbox(label="Generated SQL / model output", lines=12)
                out_local_result = gr.Textbox(label="Dummy database result", lines=14)
            with gr.Column(scale=1):
                gr.Markdown("#### Hub base (Transformers)")
                out_hf = gr.Textbox(label="Generated SQL / model output", lines=12)
                out_hf_result = gr.Textbox(label="Dummy database result", lines=14)
        # Output order must match the tuple returned by ``run_compare``.
        btn.click(
            fn=run_compare,
            inputs=[inp],
            outputs=[out_local, out_local_result, out_hf, out_hf_result],
        )
    # ``SPACE_ID`` being set is used as the signal for the hosted environment.
    in_space = bool(os.environ.get("SPACE_ID"))
    host = _env("QWEN_COMPARE_GRADIO_HOST", "0.0.0.0" if in_space else "127.0.0.1")
    preferred = int(_env("QWEN_COMPARE_GRADIO_PORT", os.environ.get("PORT", "7860") if in_space else "7861"))
    # Locally, probe for a free port starting at the preferred one; in a Space
    # the preferred port is used without probing.
    port = preferred if in_space else _first_free_port(host, preferred)
    if not in_space and port != preferred:
        print(f"Port {preferred} busy; using {port}.", file=sys.stderr)
    demo.launch(server_name=host, server_port=port)


if __name__ == "__main__":
    main()