File size: 3,436 Bytes
03f92a0
aba90dc
 
03f92a0
 
a796c9f
 
03f92a0
 
a796c9f
daac5be
d73b7bd
aba90dc
03f92a0
d73b7bd
daac5be
 
a796c9f
 
 
03f92a0
 
aba90dc
03f92a0
aba90dc
daac5be
 
aba90dc
daac5be
a796c9f
aba90dc
daac5be
aba90dc
03f92a0
 
 
 
 
 
 
daac5be
 
a796c9f
daac5be
03f92a0
 
 
 
 
 
 
daac5be
a796c9f
03f92a0
daac5be
d73b7bd
03f92a0
 
 
 
 
 
 
 
 
 
d73b7bd
 
 
03f92a0
a796c9f
aba90dc
03f92a0
 
aba90dc
 
 
03f92a0
daac5be
 
 
a796c9f
aba90dc
 
a796c9f
 
03f92a0
 
 
 
 
 
 
aba90dc
03f92a0
 
 
 
 
aba90dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import contextlib
import json
import os
import pathlib
import sqlite3
import sys
import textwrap
import time

import gradio as gr
import requests

# ----------------- CONFIG -----------------
# Hugging Face model queried through the hosted Inference API.
# "gpt2" is always public, so it works without special access; swap it for a
# SQL-tuned model (e.g. sqlcoder) later.
MODEL_ID = "gpt2"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"

# Optional API token taken from the environment. When it is unset (or empty)
# no Authorization header is sent.
HF_TOKEN = os.environ.get("HF_TOKEN")
HEADERS = {}
if HF_TOKEN:
    HEADERS["Authorization"] = f"Bearer {HF_TOKEN}"

# Local SQLite database file and the DDL script used to build it.
SCHEMA_FILE = "schema.sql"
DB_PATH = "company.db"

# -------------- UTIL: DB ------------------
def create_db_if_needed(db_path=None, schema_file=None):
    """Build the SQLite database from the schema script if it does not exist.

    Args:
        db_path: Database file to create; defaults to ``DB_PATH``.
        schema_file: SQL script executed to initialise the database;
            defaults to ``SCHEMA_FILE``.

    Raises:
        OSError: if the schema file cannot be read.
        sqlite3.Error: if the schema script fails. The partially built
            database file is removed first so a later call retries cleanly.
    """
    # Resolve defaults at call time so reassigning the module constants works.
    db_path = DB_PATH if db_path is None else db_path
    schema_file = SCHEMA_FILE if schema_file is None else schema_file

    if os.path.exists(db_path):
        return

    # Read the script before touching the database, so a missing schema file
    # never leaves an empty database behind.
    with open(schema_file, encoding="utf-8") as f:
        script = f.read()

    try:
        # closing(): the sqlite3 connection context manager only commits or
        # rolls back; it does not close the connection.
        with contextlib.closing(sqlite3.connect(db_path)) as conn:
            conn.executescript(script)
            conn.commit()
    except Exception:
        # The exists() guard above would otherwise skip initialisation forever
        # on a half-built database. The connection is already closed here,
        # which is required for the file removal on Windows.
        with contextlib.suppress(OSError):
            os.remove(db_path)
        raise

# -------------- UTIL: CALL API ------------
def nlp_to_sql(question, schema_ddl):
    """Ask the hosted model to translate *question* into SQL.

    Builds a prompt containing the schema DDL and the question, POSTs it to
    ``API_URL`` with ``HEADERS``, and returns the text after the last
    ``### SQL`` marker in the model's ``generated_text``. Debug information
    (URL, token presence, status code, start of the raw body) is printed to
    stderr.

    Args:
        question: Natural-language question from the user.
        schema_ddl: Schema DDL text embedded verbatim in the prompt.

    Returns:
        The extracted SQL string. Failures are reported as bracketed
        diagnostic strings instead of exceptions:
        ``"[ConnErr] ..."`` (the request could not be sent or timed out),
        ``"[API <status>] ..."`` (non-200 response),
        ``"[JSONErr] ..."`` (response body of an unexpected shape) and
        ``"[Empty SQL]"`` (nothing after the marker).
    """
    # Dedent the template *before* interpolating. dedent() computes the common
    # margin over all lines, so interpolating unindented, multi-line DDL first
    # would shrink the margin to zero and leave the prompt indented.
    template = textwrap.dedent("""

        Translate the natural language question to a SQL query.



        ### Schema

        {schema_ddl}



        ### Question

        {question}



        ### SQL

    """)
    # str.format only parses braces in the template, never in the values.
    prompt = template.format(schema_ddl=schema_ddl, question=question)
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 64}}

    # ---------- DEBUG PRINTS ----------
    print("=" * 60, file=sys.stderr)
    print("DEBUG URL:", API_URL, file=sys.stderr)
    print("DEBUG Token present?:", bool(HF_TOKEN), file=sys.stderr)
    # ----------------------------------

    try:
        r = requests.post(API_URL, headers=HEADERS, json=payload, timeout=60)
    except requests.RequestException as e:
        # Connection errors, timeouts, invalid URLs, ...
        return f"[ConnErr] {e}"

    # ---------- MORE DEBUG ----------
    print("DEBUG Status code:", r.status_code, file=sys.stderr)
    print("DEBUG Raw response (first 500 bytes):", r.text[:500], file=sys.stderr)
    print("=" * 60, file=sys.stderr)
    # ---------------------------------

    if r.status_code != 200:
        return f"[API {r.status_code}] {r.text[:200]}"

    try:
        out = r.json()
        generated = out[0].get("generated_text", "No generated_text")
    except (ValueError, LookupError, AttributeError, TypeError) as e:
        # ValueError: body is not JSON. The others: JSON of an unexpected
        # shape (a dict, an empty list, a list of non-objects, ...).
        return f"[JSONErr] {e}"

    if not isinstance(generated, str):
        return f"[JSONErr] generated_text is {type(generated).__name__}, not str"

    # The model may echo the prompt, so keep only what follows the last marker.
    return generated.split("### SQL")[-1].strip() or "[Empty SQL]"

# -------------- PIPELINE ------------------
def run(query):
    """Answer *query* end to end: question -> SQL -> query result.

    Ensures the database exists, loads the schema DDL from ``SCHEMA_FILE``,
    asks ``nlp_to_sql`` for a query and executes it against ``DB_PATH``.

    Args:
        query: Natural-language question typed into the UI.

    Returns:
        A 3-tuple ``(sql, result_json, trace_text)``:

        - ``sql``: the generated SQL, or the bracketed diagnostic string
          that ``nlp_to_sql`` returned.
        - ``result_json``: a JSON object with ``columns``/``rows`` on success,
          or ``error`` on failure.
        - ``trace_text``: one ``"step: message"`` line per pipeline stage.
    """
    t0, trace = time.perf_counter(), []
    create_db_if_needed()

    with open(SCHEMA_FILE, encoding="utf-8") as f:
        schema = f.read()
    trace.append(("Schema", "loaded"))

    sql = nlp_to_sql(query, schema)
    trace.append(("LLM", sql))

    if sql.startswith("["):
        # nlp_to_sql reports failures as "[Tag] ..." strings, and valid SQL
        # never starts with "[". Executing one would only yield a confusing
        # syntax error, so surface the diagnostic directly.
        result = {"error": f"no SQL to execute: {sql}"}
        trace.append(("Exec skipped", "LLM step did not return SQL"))
    else:
        try:
            # The SQL is model output derived from user text, i.e. untrusted.
            # Open the database read-only so it cannot modify or drop data.
            db_uri = pathlib.Path(DB_PATH).resolve().as_uri() + "?mode=ro"
            # closing(): the sqlite3 connection context manager only commits
            # or rolls back; it does not close the connection.
            with contextlib.closing(sqlite3.connect(db_uri, uri=True)) as conn:
                cur = conn.execute(sql)
                rows = cur.fetchall()
                cols = [d[0] for d in cur.description] if cur.description else []
            result = {"columns": cols, "rows": rows}
            trace.append(("Exec", f"{len(rows)} rows"))
        except Exception as e:  # UI boundary: report any failure, don't crash
            result = {"error": str(e)}
            trace.append(("Exec error", str(e)))

    trace.append(("Time", f"{time.perf_counter() - t0:.2f}s"))
    # default=str: BLOB columns come back as bytes, which json cannot encode;
    # for this display-only output their string form is sufficient.
    result_json = json.dumps(result, indent=2, default=str)
    return sql, result_json, "\n".join(f"{s}: {m}" for s, m in trace)

# -------------- UI ------------------------
# Gradio app: a question box, side-by-side SQL and result panes, a trace box,
# and a Run button. Blocks lays components out in the order they are created,
# so reordering these statements changes the page.
with gr.Blocks(title="Debug NLP→SQL") as demo:
    gr.Markdown("### Debugging Hugging Face Inference API calls")
    q = gr.Textbox(label="Ask", placeholder="Example: List employees")
    # Generated SQL (or nlp_to_sql's debug/error string) next to the result.
    with gr.Row():
        sql_box = gr.Code(label="Generated SQL / debug output")
        res_box = gr.Code(label="Query result")
    tbox = gr.Textbox(label="Trace")
    btn = gr.Button("Run")
    # run(query) returns (sql, result_json, trace_text); the three values fill
    # the output components in this order.
    btn.click(run, q, [sql_box, res_box, tbox])

if __name__ == "__main__":
    # Launch the server only when executed as a script, not on import.
    demo.launch()