Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import requests | |
| import os | |
| import subprocess | |
| import wget | |
| from loguru import logger | |
| from data_utils.line_based_parsing import parse_line_based_query, convert_to_lines | |
| from data_utils.base_conversion_utils import ( | |
| build_schema_maps, | |
| convert_modified_to_actual_code_string | |
| ) | |
| from data_utils.schema_utils import schema_to_line_based | |
| from configs.prompt_config import SYSTEM_PROMPT_V3, MODEL_PROMPT_V3 | |
# Chat-completions endpoint of the local llama.cpp server; process_query POSTs
# each request here and reads choices[0].message.content from the reply.
LLAMA_SERVER_URL = "http://127.0.0.1:8080/v1/chat/completions"
# Local path of the GGUF weights: download_model saves them here on first run
# and start_llama_server passes this file to the server via --model.
MODEL_PATH = "./models/unsloth.Q8_0.gguf"
def download_model():
    """Fetch the GGUF weights into MODEL_PATH unless they are already on disk.

    Always ensures the ``./models`` directory exists. When the file is
    missing, it is downloaded from the Hugging Face repo and progress is
    logged. Returns None.
    """
    os.makedirs("./models", exist_ok=True)
    if os.path.exists(MODEL_PATH):
        # Weights already present: nothing to do.
        return

    logger.info("Downloading model weights...")
    weights_url = "https://huggingface.co/ByteMaster01/NL2SQL/resolve/main/unsloth.Q8_0.gguf"
    wget.download(weights_url, MODEL_PATH)
    # The leading newline presumably ends the line left by wget's progress bar.
    logger.info("\nModel download complete!")
def start_llama_server():
    """Launch the llama-cpp-python server for MODEL_PATH in a background process.

    The host and port are taken from ``LLAMA_SERVER_URL``, so the server
    listens exactly where ``process_query`` sends requests. This function
    returns as soon as the child process is spawned. It does NOT wait for
    the model to load or for the port to accept connections.

    Returns:
        The ``subprocess.Popen`` handle of the server process. Returning it
        is new; callers that ignore the return value are unaffected.

    Raises:
        Exception: whatever ``subprocess.Popen`` raised, re-raised after logging.
    """
    import sys
    from urllib.parse import urlsplit

    server_url = urlsplit(LLAMA_SERVER_URL)
    try:
        logger.info("Starting llama.cpp server...")
        process = subprocess.Popen([
            # Use the interpreter running this app: a bare "python" on PATH
            # may belong to another environment without llama_cpp installed.
            sys.executable, "-m", "llama_cpp.server",
            "--model", MODEL_PATH,
            "--host", server_url.hostname or "127.0.0.1",
            "--port", str(server_url.port or 8080),
        ])
        # Only the process has been spawned; the model may still be loading.
        logger.info(f"llama.cpp server process launched (pid {process.pid})")
        return process
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        raise
def convert_line_parsed_to_mongo(line_parsed: str, schema: dict) -> str:
    """Turn the model's line-based output into MongoDB query code.

    Args:
        line_parsed: Line-based query text as produced by the model.
        schema: Parsed schema dict. The first entry of
            ``schema["collections"]`` names the target collection.

    Returns:
        The reconstructed query string, or ``""`` if any step fails. This
        is a best-effort conversion: failures are logged, not raised.
    """
    try:
        parsed_query = parse_line_based_query(line_parsed)
        target_collection = schema["collections"][0]["name"]
        # Only the first map from build_schema_maps is needed here; presumably
        # it maps internal field names to their output form — verify in data_utils.
        field_map = build_schema_maps(schema)[0]
        return convert_modified_to_actual_code_string(parsed_query, field_map, target_collection)
    except Exception as err:
        logger.error(f"Error converting line parsed to MongoDB query: {err}")
        return ""
def process_query(schema_text: str, nl_query: str, additional_info: str = "") -> list[str]:
    """Translate a natural-language request into a MongoDB query via the local LLM.

    Pipeline:
      1. Parse the JSON schema.
      2. Render it in the line-based prompt format.
      3. Ask the llama.cpp chat-completions endpoint for a line-based query.
      4. Convert that output into MongoDB query code.

    Args:
        schema_text: The collection schema as a JSON string.
        nl_query: The user's request in natural language.
        additional_info: Optional extra context inserted into the prompt
            (e.g. the current date or timestamps). ``None`` is treated as "".

    Returns:
        A two-element list ``[mongo_query, line_based_output]``, one value per
        Gradio output component. ``mongo_query`` is "" if the conversion step
        fails. On any other error, both elements hold the same
        ``"Error: ..."`` message.
    """
    # (connect, read) timeouts in seconds: fail fast if the server is down,
    # but allow slow CPU generation before giving up on a hung request.
    request_timeout = (10, 600)
    try:
        # Parse schema from string to dict
        schema = json.loads(schema_text)
        # Convert schema to line-based format
        line_based_schema = schema_to_line_based(schema)
        # Format prompt with line-based schema
        prompt = MODEL_PROMPT_V3.format(
            schema=line_based_schema,
            natural_language_query=nl_query,
            additional_info=additional_info or "",
        )
        # Prepare request payload.
        # NOTE(review): slot_id / n_keep / cache_prompt look like options of the
        # C++ llama.cpp server; the Python llama_cpp.server may ignore them — confirm.
        payload = {
            "slot_id": 0,
            "temperature": 0.1,
            "n_keep": -1,
            "cache_prompt": True,
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT_V3,
                },
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
        }
        # Make request to llama.cpp server
        response = requests.post(LLAMA_SERVER_URL, json=payload, timeout=request_timeout)
        response.raise_for_status()
        # Extract output from response
        output = response.json()["choices"][0]["message"]["content"].strip()
        logger.info(f"Model output: {output}")
        # Convert line-based output to MongoDB query
        mongo_query = convert_line_parsed_to_mongo(output, schema)
        return [
            mongo_query,
            output,
        ]
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        error_msg = f"Error: {str(e)}"
        # Exactly one value per output component: the interface declares two,
        # and returning a different count makes Gradio raise instead of
        # displaying the error.
        return [error_msg, error_msg]
def create_interface():
    """Assemble and return the Gradio ``Interface`` that wraps ``process_query``.

    Inputs (three textboxes): the schema JSON, the natural-language query,
    and optional extra context. Outputs: a JavaScript code view for the
    MongoDB query and a textbox for the raw line-based model output. Four
    worked examples are pre-loaded; ``cache_examples=False`` keeps Gradio
    from running them at startup.
    """
    input_components = [
        gr.Textbox(
            label="Schema (JSON format)",
            placeholder="Enter your MongoDB schema in JSON format...",
            lines=10,
        ),
        gr.Textbox(
            label="Natural Language Query",
            placeholder="Enter your query in natural language...",
        ),
        gr.Textbox(
            label="Additional Info (Optional)",
            placeholder="Enter any additional context (timestamps, etc)...",
        ),
    ]

    # Order must match process_query's two-element return value.
    output_components = [
        gr.Code(label="MongoDB Query", language="javascript", lines=1),
        gr.Textbox(label="Line-based Query"),
    ]

    events_schema = '''{
"collections": [{
"name": "events",
"document": {
"properties": {
"timestamp": {"bsonType": "int"},
"severity": {"bsonType": "int"},
"location": {
"bsonType": "object",
"properties": {
"lat": {"bsonType": "double"},
"lon": {"bsonType": "double"}
}
}
}
}
}]}'''

    vehicles_schema = '''{
"collections": [{
"name": "vehicles",
"document": {
"properties": {
"timestamp": {"bsonType": "int"},
"vehicle_details": {
"bsonType": "object",
"properties": {
"license_plate": {"bsonType": "string"},
"make": {"bsonType": "string"},
"model": {"bsonType": "string"},
"year": {"bsonType": "int"},
"color": {"bsonType": "string"}
}
},
"speed": {"bsonType": "double"},
"location": {
"bsonType": "object",
"properties": {
"lat": {"bsonType": "double"},
"lon": {"bsonType": "double"}
}
}
}
}
}]}'''

    sensors_schema = '''{
"collections": [{
"name": "sensors",
"document": {
"properties": {
"sensor_id": {"bsonType": "string"},
"readings": {
"bsonType": "object",
"properties": {
"temperature": {"bsonType": "double"},
"humidity": {"bsonType": "double"},
"pressure": {"bsonType": "double"}
}
},
"timestamp": {"bsonType": "date"},
"status": {"bsonType": "string"}
}
}
}]}'''

    orders_schema = '''{
"collections": [{
"name": "orders",
"document": {
"properties": {
"order_id": {"bsonType": "string"},
"customer": {
"bsonType": "object",
"properties": {
"id": {"bsonType": "string"},
"name": {"bsonType": "string"},
"email": {"bsonType": "string"}
}
},
"items": {
"bsonType": "array",
"items": {
"bsonType": "object",
"properties": {
"product_id": {"bsonType": "string"},
"quantity": {"bsonType": "int"},
"price": {"bsonType": "double"}
}
}
},
"total_amount": {"bsonType": "double"},
"status": {"bsonType": "string"},
"created_at": {"bsonType": "int"}
}
}
}]}'''

    # Each row is [schema, natural-language query, additional info].
    example_rows = [
        [events_schema, "Find all events with severity greater than 5", ""],
        [vehicles_schema, "Find red Toyota vehicles manufactured after 2020 with speed above 60", ""],
        [
            sensors_schema,
            "Find active sensors with temperature above 30 degrees in the last one day",
            "current date is 21 january 2025",
        ],
        [
            orders_schema,
            "Find orders with total amount greater than $100 that contain more than 3 items and were created in the last 24 hours",
            '{"current_time": 1685890800, "last_24_hours": 1685804400}',
        ],
    ]

    return gr.Interface(
        fn=process_query,
        inputs=input_components,
        outputs=output_components,
        title="Natural Language to MongoDB Query Converter",
        description="Convert natural language queries to MongoDB queries based on your schema.",
        examples=example_rows,
        cache_examples=False,
    )
def wait_for_server(host: str, port: int, timeout: float = 300.0, interval: float = 1.0) -> bool:
    """Poll until a TCP connection to ``host:port`` succeeds or ``timeout`` seconds pass.

    At least one connection attempt is always made.

    Args:
        host: Hostname or IP to connect to.
        port: TCP port to connect to.
        timeout: Overall deadline in seconds.
        interval: Per-attempt connect timeout and pause between attempts, in seconds.

    Returns:
        True as soon as the port accepts a connection; False if the deadline
        passes first.
    """
    import socket
    import time

    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)


if __name__ == "__main__":
    from urllib.parse import urlsplit

    # Download the model
    download_model()
    # Start the llama.cpp server
    start_llama_server()
    # Wait for the server to actually accept connections instead of sleeping a
    # fixed 5 s: loading a Q8 GGUF model can take much longer, and queries sent
    # before then fail with a connection error.
    llama_address = urlsplit(LLAMA_SERVER_URL)
    if not wait_for_server(llama_address.hostname or "127.0.0.1", llama_address.port or 8080):
        logger.warning("llama.cpp server is not accepting connections yet; queries may fail until it is ready")
    # Launch the Gradio interface
    logger.info("Starting Gradio interface...")
    iface = create_interface()
    iface.launch()