# NOTE(review): removed stray "Spaces: Sleeping" banner text captured from the
# hosting page; it was never part of the Python source.
# model_api.py
import re

import requests

# Any hosted Hugging Face model id can be substituted here.
HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
def query_hf_model(prompt: str, hf_token: str, max_tokens: int = 256) -> str:
    """Generate SQL from *prompt* via the Hugging Face Inference API.

    Args:
        prompt: Full prompt text sent to the hosted model (``HF_MODEL``).
        hf_token: Hugging Face API token used as a Bearer credential.
        max_tokens: Upper bound on newly generated tokens (default 256).

    Returns:
        The extracted SQL query string; if no SELECT statement can be
        located, the cleaned model output is returned as-is.

    Raises:
        requests.HTTPError: On a non-2xx response from the API.
        RuntimeError: When the API returns an ``error`` payload.
    """
    api_url = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        # temperature 0.0 => deterministic decoding, important for SQL
        "parameters": {"max_new_tokens": max_tokens, "temperature": 0.0},
        # wait_for_model avoids 503s while a cold model spins up
        "options": {"wait_for_model": True},
    }
    resp = requests.post(api_url, headers=headers, json=payload, timeout=60)
    resp.raise_for_status()
    data = resp.json()

    # Extract generated text defensively: the API may return a list of
    # dicts, a bare dict (possibly an error payload), or something else.
    if isinstance(data, list) and data:
        item = data[0]
        if isinstance(item, dict):
            text = item.get("generated_text") or item.get("text") or str(item)
        else:
            text = str(item)
    elif isinstance(data, dict):
        if "error" in data:
            raise RuntimeError(f"Model error: {data['error']}")
        text = data.get("generated_text") or data.get("text") or str(data)
    else:
        text = str(data)

    # If the model wrapped its answer in a fenced code block, keep the
    # fence CONTENTS. (The previous code deleted the entire fenced span,
    # discarding the generated SQL along with the backticks.)
    fence = re.search(r"```(?:\w+)?\s*(.*?)```", text, flags=re.S)
    if fence:
        text = fence.group(1)

    # Return from the first SELECT onward; re.M lets the match start on
    # any line, not just at the very beginning of the text.
    match = re.search(r"(?im)^\s*select\b.*", text, flags=re.S)
    if match:
        return match.group(0).strip()
    return text.strip()