Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import requests | |
| from config import ACCESS_TOKEN, SHOP_NAME | |
class SQLGenerator:
    """Translate natural-language questions into SQL and fetch Shopify data.

    Wraps a Hugging Face causal language model (``premai-io/prem-1B-SQL`` by
    default) that is prompted with a fixed ``products`` table schema, plus a
    thin helper around the Shopify Admin REST API.
    """

    # Schema shown to the model in every prompt. The columns mirror the fields
    # of a Shopify product (the resource fetched by fetch_shopify_data("products")).
    # Types corrected from the earlier draft, since a wrong schema steers the
    # model toward wrong SQL:
    #   - Shopify product ids are large integers; DECIMAL(8,2) cannot hold them
    #     and implies fractional ids.
    #   - admin_graphql_api_id is a "gid://shopify/..." string, not a number.
    #   - created_at / updated_at / published_at now share one timestamp type
    #     instead of mixing VARCHAR and DATE.
    #   - body_html holds an HTML description that can exceed 255 characters.
    PRODUCTS_SCHEMA = """CREATE TABLE products (
    id BIGINT PRIMARY KEY,
    title VARCHAR(255),
    body_html TEXT,
    vendor VARCHAR(255),
    product_type VARCHAR(255),
    created_at TIMESTAMP,
    handle VARCHAR(255),
    updated_at TIMESTAMP,
    published_at TIMESTAMP,
    template_suffix VARCHAR(255),
    published_scope VARCHAR(255),
    tags VARCHAR(255),
    status VARCHAR(255),
    admin_graphql_api_id VARCHAR(255),
    variants VARCHAR(255),
    options VARCHAR(255),
    images VARCHAR(255),
    image VARCHAR(255)
);"""

    # Admin API version segment used to build request URLs.
    # NOTE(review): Shopify retires API versions on a rolling schedule;
    # confirm this version is still served before relying on it.
    SHOPIFY_API_VERSION = "2023-10"

    def __init__(self, model_name="premai-io/prem-1B-SQL"):
        """Load the tokenizer and causal language model.

        Args:
            model_name: Hugging Face Hub id or local path of the model.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

    @classmethod
    def _build_prompt(cls, natural_language_query):
        """Return the text prompt: task header, schema, question, answer cue."""
        return (
            "### Task: Generate a SQL query to answer the following question.\n"
            "### Database Schema:\n"
            f"{cls.PRODUCTS_SCHEMA}\n"
            f"### Question: {natural_language_query}\n"
            "### SQL Query:"
        )

    def generate_query(
        self,
        natural_language_query,
        *,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
    ):
        """Generate a SQL query that answers *natural_language_query*.

        Args:
            natural_language_query: Question to translate, e.g.
                "Show me shirts with red color".
            max_new_tokens: Maximum number of tokens to generate, counted
                separately from the prompt.
            do_sample: Sample instead of greedy decoding. Sampling (the
                default) means repeated calls may return different SQL.
            temperature: Sampling temperature; only passed when do_sample.
            top_k: Top-k sampling cutoff; only passed when do_sample.

        Returns:
            The model's completion with the prompt removed and surrounding
            whitespace stripped. The text is not validated as SQL.
        """
        prompt = self._build_prompt(natural_language_query)
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)

        # Some tokenizers define no pad token; generate() then needs an
        # explicit id, and EOS is the conventional stand-in.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs = {
            # max_new_tokens, unlike max_length, does not count the (long)
            # schema prompt, so the model always has room to answer.
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "num_return_sequences": 1,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": pad_token_id,
        }
        if do_sample:
            # Sampling-only knobs; passing them to greedy decoding just warns.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_k"] = top_k

        input_ids = inputs["input_ids"]
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            **generation_kwargs,
        )

        # A decoder-only model returns prompt + completion; keep only the
        # newly generated tokens so the caller gets the SQL, not the prompt.
        completion_ids = outputs[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    def fetch_shopify_data(self, endpoint, timeout=30):
        """GET ``<endpoint>.json`` from the Shopify Admin REST API.

        Args:
            endpoint: Resource path relative to the versioned Admin API root,
                e.g. "products".
            timeout: Seconds requests may wait to connect / read before it
                raises requests.Timeout; without one a stalled connection
                would block forever.

        Returns:
            The decoded JSON body on HTTP 200. For any other status code the
            status and response body are printed and None is returned.
            Network-level failures propagate as requests exceptions.
        """
        headers = {
            'X-Shopify-Access-Token': ACCESS_TOKEN,
            'Content-Type': 'application/json'
        }
        url = f"https://{SHOP_NAME}/admin/api/{self.SHOPIFY_API_VERSION}/{endpoint}.json"
        response = requests.get(url, headers=headers, timeout=timeout)
        if response.status_code == 200:
            return response.json()
        print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
        return None
# Example of how to use the SQLGenerator class
if __name__ == "__main__":
    # Demo run: translate one sample question into SQL, then pull the raw
    # product listing from the Shopify Admin API and print both results.
    generator = SQLGenerator()

    question = "Show me shirts with red color"
    generated_sql = generator.generate_query(question)
    print(f"Generated SQL Query: {generated_sql}")

    products_payload = generator.fetch_shopify_data("products")
    print(f"Shopify Data: {products_payload}")