File size: 1,576 Bytes
3743009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# model_api.py
import requests
import re

# Hugging Face model id that every request is sent to. Any text-generation
# model hosted on the Inference API can be substituted here.
HF_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.2"

def query_hf_model(
    prompt: str, hf_token: str, max_tokens: int = 256, timeout: float = 60
) -> str:
    """Ask the Hugging Face Inference API to turn *prompt* into a SQL query.

    Args:
        prompt: Text sent to the model as ``inputs``.
        hf_token: Hugging Face API token, sent as a Bearer credential.
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The SQL extracted from the model output (see ``_extract_sql``). If no
        SELECT/WITH statement can be located, the stripped model text is
        returned as-is.

    Raises:
        requests.HTTPError: The API answered with a non-2xx status.
        requests.Timeout: No response within *timeout* seconds.
        RuntimeError: The API reported a model error or returned no result.
    """
    api_url = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_tokens,
            # Greedy (deterministic) decoding. temperature=0.0 is rejected by
            # TGI-served models ("temperature must be strictly positive").
            "do_sample": False,
            # By default the API echoes the prompt at the start of
            # generated_text, which would then be returned as "SQL".
            "return_full_text": False,
        },
        # Block until a cold model is loaded instead of failing with 503.
        "options": {"wait_for_model": True},
    }

    resp = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    return _extract_sql(_extract_generated_text(resp.json()))


def _extract_generated_text(data: object) -> str:
    """Return the generated text from a decoded Inference API JSON response.

    Accepts either a list whose first element carries the result or a bare
    dict. The text is read from ``generated_text``, then ``text``. For other
    shapes, ``str()`` of the payload is returned.

    Raises:
        RuntimeError: *data* is an ``{"error": ...}`` dict or an empty list.
    """
    if isinstance(data, dict) and "error" in data:
        raise RuntimeError(f"Model error: {data['error']}")
    if isinstance(data, list):
        if not data:
            # Previously returned the literal "[]" as if it were SQL.
            raise RuntimeError("Model returned an empty response")
        data = data[0]
    if isinstance(data, dict):
        for key in ("generated_text", "text"):
            value = data.get(key)
            # Check the type rather than truthiness so an empty generation
            # yields "" instead of the dict's repr.
            if isinstance(value, str):
                return value
    return str(data)


def _extract_sql(text: str) -> str:
    """Return the SQL statement embedded in raw model output *text*.

    1. If a Markdown code fence is present, only its body is kept. The fence
       may carry a language tag such as ```sql and may be unclosed.
    2. Text already starting with SELECT or WITH is returned whole, which
       keeps CTEs intact.
    3. Otherwise, any prose before the first line beginning with SELECT is
       dropped.

    If no such line exists, the stripped text is returned unchanged.
    """
    # Keep what is *inside* the fence. Removing the whole fenced span would
    # discard exactly the SQL the model wrapped in it.
    fenced = re.search(r"```(?:[\w+-]*[ \t]*\r?\n)?(.*?)```", text, flags=re.S)
    if fenced:
        text = fenced.group(1)
    text = text.strip()

    # A CTE body contains inner SELECT lines; don't start from one of those.
    if re.match(r"(?:select|with)\b", text, flags=re.I):
        return text

    # re.M lets ^ match at any line start, so leading prose is skipped.
    match = re.search(r"^[ \t]*select\b.*", text, flags=re.I | re.M | re.S)
    return match.group(0).strip() if match else text