Spaces:
Build error
Build error
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from utils.logger import setup_logger | |
| from utils.model_loader import ModelLoader | |
| from api.shopify_client import ShopifyClient | |
# Module-level logger obtained from the project's setup_logger helper; SQLGenerator
# uses it to report initialization and query-generation failures.
logger = setup_logger(__name__)
class SQLGenerator:
    """Translate natural-language questions into SQL using prem-1B-SQL.

    The tokenizer and causal-LM weights for ``premai-io/prem-1B-SQL`` are loaded
    through ``ModelLoader.load_model_with_retry`` at construction time, together
    with a ``ShopifyClient`` used by :meth:`fetch_shopify_data`.
    """

    # Hugging Face model id used for both the tokenizer and the model weights.
    MODEL_NAME = "premai-io/prem-1B-SQL"

    # Default schema shown to the model; the model writes SQL against exactly this text.
    # NOTE(review): several column types look auto-inferred rather than faithful to
    # Shopify product data (e.g. `id` / `admin_graphql_api_id` as DECIMAL(8,2),
    # `updated_at` as DATE while the other timestamps are VARCHAR) -- confirm against
    # the real table before changing.
    PRODUCTS_SCHEMA = (
        "CREATE TABLE products (\n"
        "    id DECIMAL(8,2) PRIMARY KEY,\n"
        "    title VARCHAR(255),\n"
        "    body_html VARCHAR(255),\n"
        "    vendor VARCHAR(255),\n"
        "    product_type VARCHAR(255),\n"
        "    created_at VARCHAR(255),\n"
        "    handle VARCHAR(255),\n"
        "    updated_at DATE,\n"
        "    published_at VARCHAR(255),\n"
        "    template_suffix VARCHAR(255),\n"
        "    published_scope VARCHAR(255),\n"
        "    tags VARCHAR(255),\n"
        "    status VARCHAR(255),\n"
        "    admin_graphql_api_id DECIMAL(8,2),\n"
        "    variants VARCHAR(255),\n"
        "    options VARCHAR(255),\n"
        "    images VARCHAR(255),\n"
        "    image VARCHAR(255)\n"
        ");"
    )

    # Returned by generate_query (instead of raising) when generation fails.
    ERROR_MESSAGE = "Failed to generate SQL query due to an error."

    def __init__(self):
        """Load the tokenizer and model, and create the Shopify client.

        Raises:
            Exception: anything raised by ModelLoader or ShopifyClient; it is
                logged and then re-raised unchanged.
        """
        try:
            self.model_name = self.MODEL_NAME
            self.tokenizer = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoTokenizer,
            )
            self.model = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoModelForCausalLM,
            )
            self.shopify_client = ShopifyClient()
        except Exception as e:
            logger.error(f"Failed to initialize SQLGenerator: {str(e)}")
            raise

    @staticmethod
    def _build_prompt(natural_language_query, schema_info):
        """Return the instruction prompt sent to the model."""
        return (
            "### Task: Generate a SQL query to answer the following question.\n"
            f"### Database Schema:\n{schema_info}\n"
            f"### Question: {natural_language_query}\n"
            "### SQL Query:"
        )

    def generate_query(self, natural_language_query, schema_info=None, max_new_tokens=256):
        """Generate a SQL query answering *natural_language_query*.

        Args:
            natural_language_query: The user's question in plain language.
            schema_info: ``CREATE TABLE`` text describing the tables the query
                may use. Defaults to :attr:`PRODUCTS_SCHEMA`.
            max_new_tokens: Maximum number of tokens generated after the prompt.

        Returns:
            The generated SQL text (prompt removed, surrounding whitespace
            stripped), or :attr:`ERROR_MESSAGE` if any step fails. Failures are
            logged, not raised.
        """
        try:
            schema = self.PRODUCTS_SCHEMA if schema_info is None else schema_info
            prompt = self._build_prompt(natural_language_query, schema)

            inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
            # Keep the input tensors on the same device as the model weights.
            device = getattr(self.model, "device", None)
            if device is not None:
                inputs = inputs.to(device)
            input_ids = inputs["input_ids"]

            # Some tokenizers define no pad token; reuse EOS so generate() gets a real id.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            outputs = self.model.generate(
                input_ids,
                attention_mask=inputs.get("attention_mask"),
                # max_new_tokens, not max_length: max_length also counts the
                # (schema-heavy) prompt and can leave no room for the answer.
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )

            # generate() returns prompt + continuation; decode only the continuation.
            prompt_length = len(input_ids[0])
            generated_ids = outputs[0][prompt_length:]
            return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        except Exception as e:
            logger.error(f"Query generation error: {str(e)}")
            return self.ERROR_MESSAGE

    def fetch_shopify_data(self, endpoint):
        """Return the result of ``ShopifyClient.fetch_data`` for *endpoint*."""
        return self.shopify_client.fetch_data(endpoint)