Spaces:
Sleeping
Sleeping
File size: 1,576 Bytes
3743009 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# model_api.py
import requests
import re
# Hosted model queried via the HF Inference API; swap for any text-generation model ID.
HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2" # You can change to any hosted HF model
def query_hf_model(prompt: str, hf_token: str, max_tokens: int = 256) -> str:
    """
    Calls the Hugging Face Inference API to generate SQL from a prompt.

    Args:
        prompt: Instruction text sent to the model as ``inputs``.
        hf_token: Hugging Face API token, sent as a Bearer credential.
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).

    Returns:
        The SQL query as a string: the first SELECT statement found in the
        model output, or the cleaned raw text if no SELECT is present.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
        RuntimeError: If the API returns an explicit ``error`` payload.
    """
    api_url = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        # temperature 0.0 for deterministic output; wait_for_model avoids
        # 503s while the hosted model cold-starts.
        "parameters": {"max_new_tokens": max_tokens, "temperature": 0.0},
        "options": {"wait_for_model": True}
    }
    resp = requests.post(api_url, headers=headers, json=payload, timeout=60)
    resp.raise_for_status()
    data = resp.json()
    # The Inference API returns either a list of generation dicts, a single
    # dict (possibly carrying an "error" key), or — defensively — anything
    # else; normalize all shapes to a plain string.
    if isinstance(data, list) and len(data) > 0:
        item = data[0]
        if isinstance(item, dict):
            text = item.get("generated_text") or item.get("text") or str(item)
        else:
            text = str(item)
    elif isinstance(data, dict):
        if "error" in data:
            raise RuntimeError(f"Model error: {data['error']}")
        text = data.get("generated_text") or data.get("text") or str(data)
    else:
        text = str(data)
    return _extract_sql(text)


def _extract_sql(text: str) -> str:
    """Pull the SQL statement out of raw model output.

    Prefers the contents of a Markdown code fence when one is present, then
    returns the first SELECT statement found anywhere in the text, falling
    back to the cleaned text if none is found.
    """
    # BUG FIX: the previous code did re.sub(r"```.*?```", "", ...) which
    # deleted the ENTIRE fenced span — discarding the SQL whenever the model
    # wrapped it in a code fence. Keep the fenced content instead.
    fenced = re.search(r"```(?:\w+)?\s*(.*?)```", text, flags=re.S)
    if fenced:
        text = fenced.group(1)
    text = text.strip()
    # BUG FIX: the previous pattern anchored with `^` but without re.M, so it
    # only matched SELECT at the very start of the output; any preamble
    # ("Here is the query: ...") defeated it. Match SELECT anywhere.
    match = re.search(r"(?i)\bselect\b.*", text, flags=re.S)
    if match:
        return match.group(0).strip()
    return text
|