Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import sqlite3 | |
| import uuid | |
| import pandas as pd | |
| import subprocess | |
| import threading | |
| import time | |
| from flask import Flask, jsonify, request, render_template | |
# Flask application object. Template and static folders are given as paths
# one level up ("../templates", "../static").
app = Flask(__name__, template_folder="../templates", static_folder="../static")

# Global state shared by the request handlers and the background threads.
model_pipeline = None  # TextToSQLInference instance once loaded, otherwise None
model_loading = False  # True while load_model_async() is executing
model_loading_error = None  # str(exception) of the last failed load, otherwise None
train_process = None  # subprocess.Popen handle of the most recent training run
train_log_path = "models/train.log"  # file that receives the training output
job_store = {}  # {job_id: {"status": "running"|"done"|"error", "result": {...}}}
def _free_model_memory():
    """Drop the current pipeline and reclaim as much memory as possible."""
    global model_pipeline
    print("=== Thread: Reloading Model. Freeing memory... ===")
    model_pipeline = None
    import gc
    gc.collect()
    # torch is only present on the CUDA/dev build; the CPU Space serves via
    # llama.cpp and has no torch, so the VRAM cleanup is optional.
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass
    time.sleep(1)  # Give the garbage collector a moment to finish


def load_model_async(force=False):
    """Load the Text-to-SQL pipeline; meant to run on a background thread.

    Nothing is returned. The outcome is published through module globals:
    ``model_pipeline`` receives the new ``TextToSQLInference`` instance,
    ``model_loading`` is True for the duration of the call and False
    afterwards, and ``model_loading_error`` holds ``str(exception)`` when
    any step failed (it is reset to None at the start of each call).

    Args:
        force: When true, the current pipeline is released and memory is
            reclaimed before the new pipeline is constructed.
    """
    global model_pipeline, model_loading, model_loading_error
    try:
        model_loading, model_loading_error = True, None
        if force:
            _free_model_memory()

        print("=== Thread: Loading Text-to-SQL Model ===")
        # Imported here rather than at module level so that server startup
        # is not delayed by the model stack.
        from inference import TextToSQLInference
        model_pipeline = TextToSQLInference()
        print("=== Thread: Model Loaded Successfully! ===")
    except Exception as exc:
        model_loading_error = str(exc)
        print(f"=== Thread Error Loading Model: {model_loading_error} ===")
    finally:
        model_loading = False
def index():
    """Render and return the ``index.html`` template."""
    page = render_template("index.html")
    return page
def model_status():
    """Report the model's lifecycle state as a JSON response.

    The ``status`` field is one of:
      * ``"loaded"``   - a pipeline exists; ``has_adapter`` and ``device``
        are copied from it.
      * ``"loading"``  - a load is currently in progress.
      * ``"error"``    - the last load failed; ``error`` carries the message.
      * ``"unloaded"`` - none of the above.
    """
    # The checks are ordered: a loaded pipeline wins over a stale error.
    if model_pipeline is not None:
        payload = {
            "status": "loaded",
            "has_adapter": model_pipeline.has_adapter,
            "device": model_pipeline.device,
        }
    elif model_loading:
        payload = {"status": "loading"}
    elif model_loading_error:
        payload = {"status": "error", "error": model_loading_error}
    else:
        payload = {"status": "unloaded"}
    return jsonify(payload)
def load_model():
    """Start loading the model on a background thread.

    Reads an optional JSON body ``{"force": bool}``. A missing, malformed or
    non-object body is treated as ``{}`` instead of aborting the request.

    Returns a JSON response whose ``status`` is:
      * ``"already_loaded"`` - a pipeline exists and ``force`` is false.
      * ``"loading"``        - a load is already in progress.
      * ``"started"``        - a new background load was launched.
    """
    global model_pipeline, model_loading

    # get_json(silent=True) yields None rather than raising when the request
    # carries no JSON; plain `request.json` would abort before `or {}` runs.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    force = data.get("force", False)

    if model_pipeline is not None and not force:
        return jsonify({"status": "already_loaded"})
    if model_loading:
        return jsonify({"status": "loading"})

    # Raise the flag here, before the thread exists. Otherwise a second
    # request arriving before the worker sets it would start a second load.
    model_loading = True
    thread = threading.Thread(target=load_model_async, args=(force,), daemon=True)
    try:
        thread.start()
    except RuntimeError:
        # The worker never ran, so its `finally` will not clear the flag.
        model_loading = False
        raise
    return jsonify({"status": "started"})
def run_query():
    """Queue a natural-language question for the model and return a job id.

    Expects a JSON body with ``question`` (required, non-empty) and
    ``use_adapter`` (optional, default True). The question is processed on a
    daemon thread; the result is stored in ``job_store`` under the job id.

    Returns:
        ``{"job_id": ..., "status": "running"}`` on success, or a
        ``{"success": False, "error": ...}`` payload with HTTP 400 when the
        model is not loaded or the question is missing.
    """
    if model_pipeline is None:
        return jsonify({"success": False, "error": "Model is not loaded. Please wait for it to load."}), 400

    # Tolerate a missing / non-JSON / non-object body; the "question"
    # check below then produces the proper 400 response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    question = data.get("question", "")
    use_adapter = data.get("use_adapter", True)
    if not question:
        return jsonify({"success": False, "error": "Question is required."}), 400

    job_id = str(uuid.uuid4())[:8]
    job_store[job_id] = {"status": "running", "result": None}

    def run_job():
        """Run the pipeline and record either the result or the failure."""
        print(f"Job {job_id}: Question='{question}', UseAdapter={use_adapter}")
        try:
            result = model_pipeline.query_pipeline(question, use_adapter=use_adapter)
        except Exception as exc:
            # Without this the job would stay "running" forever and the
            # client polling for it would never finish.
            job_store[job_id] = {
                "status": "error",
                "result": {"success": False, "error": str(exc)},
            }
            print(f"Job {job_id}: Failed: {exc}")
            return
        job_store[job_id] = {"status": "done", "result": result}
        print(f"Job {job_id}: Complete")

    thread = threading.Thread(target=run_job, daemon=True)
    thread.start()
    return jsonify({"job_id": job_id, "status": "running"})
def query_result(job_id):
    """Return the stored record for ``job_id`` as JSON.

    Responds with ``{"status": "not_found"}`` and HTTP 404 when no record
    is stored under that id.
    """
    entry = job_store.get(job_id)
    if entry:
        return jsonify(entry)
    return jsonify({"status": "not_found"}), 404
def start_train():
    """Launch ``src/train.py`` as a subprocess, logging to ``train_log_path``.

    Only one training run may be active at a time. The log file is truncated
    and given a header line, then the child's stdout and stderr are appended
    to it.

    Returns a JSON response whose ``status`` is:
      * ``"running"`` - a previous training process has not exited yet.
      * ``"started"`` - the subprocess was launched.
      * ``"failed"``  - launching raised; ``error`` carries the message
        (HTTP 500).
    """
    global train_process
    if train_process is not None and train_process.poll() is None:
        return jsonify({"status": "running", "message": "Training is already in progress."})

    # Prepare directory
    os.makedirs("models", exist_ok=True)

    # Start from an empty log containing only the header line
    with open(train_log_path, "w", encoding="utf-8") as f:
        f.write("=== Fine-Tuning Process Initiated ===\n")

    try:
        # Launch python src/train.py as a separate subprocess
        cmd = [sys.executable, "src/train.py"]
        print(f"Launching training process: {' '.join(cmd)}")

        # Ensure UTF-8 mode on Windows for loading libraries correctly
        env = os.environ.copy()
        env["PYTHONUTF8"] = "1"

        # The child receives its own copy of the log descriptor, so the
        # parent's handle is closed as soon as Popen returns (or raises)
        # instead of leaking one open file per launch.
        with open(train_log_path, "a", encoding="utf-8") as log_file:
            train_process = subprocess.Popen(
                cmd,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                cwd=os.getcwd(),
                text=True,
                env=env,
            )
        return jsonify({"status": "started", "message": "Training process launched successfully."})
    except Exception as e:
        return jsonify({"status": "failed", "error": str(e)}), 500
def train_status():
    """Report the training subprocess state plus the tail of its log.

    Every response carries ``logs`` (at most the last 100 lines of
    ``train_log_path``, or an error message if the file could not be read)
    and a ``status``:
      * ``"idle"``      - no training process has been started.
      * ``"running"``   - the process has not exited yet.
      * ``"completed"`` - exit code 0; ``adapter_created`` tells whether
        ``models/phi3-text-to-sql-adapter`` exists.
      * ``"failed"``    - non-zero exit code, reported as ``exit_code``.
    """
    from collections import deque

    log_content = ""
    if os.path.exists(train_log_path):
        try:
            # errors="replace": the file may be read while the child is in
            # the middle of writing a multibyte character; a strict decode
            # would then discard the whole log for this poll.
            with open(train_log_path, "r", encoding="utf-8", errors="replace") as f:
                # A bounded deque keeps only the last 100 lines in memory
                # instead of materialising the entire log.
                log_content = "".join(deque(f, maxlen=100))
        except Exception as e:
            # Best effort: a log read failure must not break the status call.
            log_content = f"Error reading logs: {str(e)}"

    if train_process is None:
        return jsonify({"status": "idle", "logs": log_content})

    exit_code = train_process.poll()
    if exit_code is None:
        return jsonify({"status": "running", "logs": log_content})
    elif exit_code == 0:
        # Check if adapter directory was created to confirm success
        adapter_exists = os.path.exists("models/phi3-text-to-sql-adapter")
        return jsonify({
            "status": "completed",
            "exit_code": exit_code,
            "logs": log_content,
            "adapter_created": adapter_exists,
        })
    else:
        return jsonify({"status": "failed", "exit_code": exit_code, "logs": log_content})
def get_schema():
    """Return the hard-coded database schema as JSON for the UI sidebar.

    The response maps each table name to a list of column descriptors of
    the form ``{"name": ..., "type": ..., "key": ...}``.
    """
    # (column name, SQL type, key constraint) per table, in display order.
    tables = {
        "departments": (
            ("id", "INTEGER", "PRIMARY KEY"),
            ("name", "TEXT", "UNIQUE"),
            ("manager_id", "INTEGER", "FOREIGN KEY (employees.id)"),
        ),
        "employees": (
            ("id", "INTEGER", "PRIMARY KEY"),
            ("name", "TEXT", ""),
            ("department_id", "INTEGER", "FOREIGN KEY (departments.id)"),
            ("salary", "INTEGER", ""),
            ("hire_date", "TEXT", ""),
            ("manager_id", "INTEGER", "FOREIGN KEY (employees.id)"),
        ),
        "products": (
            ("id", "INTEGER", "PRIMARY KEY"),
            ("name", "TEXT", "UNIQUE"),
            ("category", "TEXT", ""),
            ("price", "REAL", ""),
        ),
        "sales": (
            ("id", "INTEGER", "PRIMARY KEY"),
            ("employee_id", "INTEGER", "FOREIGN KEY (employees.id)"),
            ("product_id", "INTEGER", "FOREIGN KEY (products.id)"),
            ("amount", "REAL", ""),
            ("quantity", "INTEGER", ""),
            ("sale_date", "TEXT", ""),
        ),
    }
    schema = {
        table: [
            {"name": col_name, "type": col_type, "key": col_key}
            for col_name, col_type, col_key in columns
        ]
        for table, columns in tables.items()
    }
    return jsonify(schema)
if __name__ == "__main__":
    # Make sure every directory the app works with exists.
    for folder in ("templates", "static", "models", "data"):
        os.makedirs(folder, exist_ok=True)

    # Cold start: rebuild the SQLite DB when it is missing (e.g. first run
    # in an HF Space). Failure is reported but does not stop the server.
    db_missing = not os.path.exists("data/company_sales.db")
    if db_missing:
        print("=== Database not found. Running setup... ===")
        try:
            # Put this file's directory on the import path for `database`.
            sys.path.insert(0, os.path.dirname(__file__))
            from database import setup_database
            setup_database()
            print("=== Database ready. ===")
        except Exception as err:
            print(f"=== Warning: Could not create database: {err} ===")

    # Load the model right away so users don't have to click "Load Model".
    print("=== Auto-starting model load in background thread ===")
    startup_thread = threading.Thread(
        target=load_model_async,
        args=(False,),
        daemon=True,
    )
    startup_thread.start()

    # HF Spaces requires port 7860 and host 0.0.0.0
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)