Spaces:
Sleeping
Sleeping
import os
from contextlib import closing

import pyodbc
from flask import Flask, jsonify, request
from flask_cors import CORS
from werkzeug.security import check_password_hash, generate_password_hash
app = Flask(__name__)

# -----------------------------------------------------------------
# CORS: ALLOWED_ORIGINS is a comma-separated list of origins taken
# from the environment; whitespace around each entry is stripped.
# The default "*" allows any origin.
# -----------------------------------------------------------------
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*")
CORS(app, resources={r"/*": {"origins": list(map(str.strip, ALLOWED_ORIGINS.split(",")))}})
# ------------------------------------------------------------------
# MODE selects how get_db_connection authenticates:
#   "local"            -> Windows (trusted) auth with the LOCAL_SQL_* values
#   anything else      -> SQL auth with the RDS_SQL_* values
# On Hugging Face, set MODE=server under "Variables and secrets".
# The value is lower-cased, so "LOCAL" and "local" are equivalent.
# ------------------------------------------------------------------
MODE = os.environ.get("MODE", "local").lower()

# ---- Local (Windows) connection settings ----
LOCAL_SQL_SERVER = os.environ.get("LOCAL_SQL_SERVER", r"localhost\SQLEXPRESS")
LOCAL_SQL_DATABASE = os.environ.get("LOCAL_SQL_DATABASE", "PyDetect")
LOCAL_SQL_DRIVER = os.environ.get("LOCAL_SQL_DRIVER", "{ODBC Driver 17 for SQL Server}")

# ---- Remote (Hugging Face / AWS RDS) SQL-auth connection settings ----
# RDS_SQL_SERVER is "host[,port]", e.g. mydb.abcxyz.ap-south-1.rds.amazonaws.com,1433
RDS_SQL_SERVER = os.environ.get("RDS_SQL_SERVER", "")
RDS_SQL_DATABASE = os.environ.get("RDS_SQL_DATABASE", "PyDetect")
RDS_SQL_USER = os.environ.get("RDS_SQL_USER", "")
RDS_SQL_PASSWORD = os.environ.get("RDS_SQL_PASSWORD", "")
RDS_SQL_DRIVER = os.environ.get("RDS_SQL_DRIVER", "{ODBC Driver 18 for SQL Server}")
RDS_ENCRYPT = os.environ.get("RDS_ENCRYPT", "yes")  # "yes" / "no"
# NOTE(review): TrustServerCertificate=yes disables server-certificate
# validation; consider defaulting to "no" once the RDS CA is trusted.
RDS_TRUST_CERT = os.environ.get("RDS_TRUST_SERVER_CERT", "yes")  # "yes" / "no"
# ======================================================
# Establishing the database connection using env values
# (CORE BEHAVIOR UNCHANGED for queries)
# ======================================================
def _odbc_quote(value):
    """Wrap *value* in ODBC braces so the driver reads it literally.

    Unquoted, a ';' inside a value (e.g. in a generated password) would end
    the attribute early and corrupt the connection string. Inside braces a
    literal '}' must be written as '}}'.
    """
    return "{" + str(value).replace("}", "}}") + "}"


def get_db_connection():
    """Open and return a new pyodbc connection for the configured MODE.

    - MODE == "local": Windows (trusted) authentication using LOCAL_SQL_*.
    - Any other MODE:  SQL authentication using the RDS_SQL_* credentials.

    The caller owns the returned connection and is responsible for closing it.
    """
    if MODE == "local":
        # Windows Authentication (local)
        conn_str = (
            f"DRIVER={LOCAL_SQL_DRIVER};"
            f"SERVER={LOCAL_SQL_SERVER};"
            f"DATABASE={LOCAL_SQL_DATABASE};"
            "Trusted_Connection=yes;"
        )
        return pyodbc.connect(conn_str)

    # SQL Authentication (RDS / Hugging Face). Credentials are brace-quoted
    # because they are the values most likely to contain ';' or '}'.
    conn_str = (
        f"DRIVER={RDS_SQL_DRIVER};"
        f"SERVER={RDS_SQL_SERVER};"
        f"DATABASE={RDS_SQL_DATABASE};"
        f"UID={_odbc_quote(RDS_SQL_USER)};PWD={_odbc_quote(RDS_SQL_PASSWORD)};"
        f"Encrypt={RDS_ENCRYPT};TrustServerCertificate={RDS_TRUST_CERT};"
        "Connection Timeout=30;"
    )
    # pyodbc's `timeout` sets the ODBC login timeout (SQL_ATTR_LOGIN_TIMEOUT)
    # directly, so the 30 s limit applies even if the driver ignores the
    # "Connection Timeout" keyword above.
    return pyodbc.connect(conn_str, timeout=30)
# ======================================================
# Create the Users table only on local
# (CORE CREATE SQL KEPT THE SAME)
# ======================================================
def create_user_table():
    """Create the ``Users`` table if it does not already exist.

    The cursor and connection are closed even if the statement or the
    commit raises.

    NOTE(review): ``sysobjects`` is a SQL Server compatibility view;
    ``OBJECT_ID('Users', 'U')`` is the modern equivalent.
    """
    # pyodbc's own `with connection:` commits but does not close, so use
    # contextlib.closing to guarantee close() on both objects.
    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute('''
            IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='Users' AND xtype='U')
            CREATE TABLE Users (
                id INT IDENTITY(1,1) PRIMARY KEY,
                name NVARCHAR(120) NOT NULL,
                role NVARCHAR(50) NOT NULL,
                email NVARCHAR(120) UNIQUE NOT NULL,
                password NVARCHAR(255) NOT NULL
            )
        ''')
        conn.commit()


# Initialize the table on startup ONLY IF local; in any other MODE the
# Users table is expected to exist already.
if MODE == "local":
    create_user_table()
# ===========================
# DO NOT CHANGE: API ROUTES
# ===========================
def sign_in():
    """Check an email/password pair sent as a JSON object.

    Responses (JSON ``{"message": ...}``):
        400 -- body is not a JSON object, or email/password is missing
        200 -- password matches the stored hash
        401 -- email exists but the password does not match
        404 -- no user with that email

    NOTE(review): no ``@app.route`` decorator is visible on this view in this
    file; confirm how it is registered with ``app``.
    NOTE(review): returning 404 vs 401 reveals which emails are registered;
    kept as-is to preserve the existing API contract.
    """
    payload = request.get_json(silent=True)
    data = payload if isinstance(payload, dict) else {}
    email = data.get('email')
    password = data.get('password')
    # Without this guard a missing password reaches check_password_hash and 500s.
    if not email or not password:
        return jsonify({"message": "Email and password are required"}), 400

    # Find user by email; select the hash column by name rather than
    # relying on the column order of SELECT *.
    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute('SELECT password FROM Users WHERE email = ?', (email,))
        row = cursor.fetchone()

    if row is None:
        return jsonify({"message": "Email not found"}), 404
    # Compare the provided password against the stored werkzeug hash.
    if check_password_hash(row[0], password):
        return jsonify({"message": "Login successful"}), 200
    return jsonify({"message": "Invalid email or password"}), 401
def sign_up():
    """Register a new user from a JSON object with name, role, email, password.

    Responses (JSON ``{"message": ...}``):
        400 -- email/password missing, name/role missing, or email already in use
        201 -- user inserted; only a werkzeug hash of the password is stored

    NOTE(review): no ``@app.route`` decorator is visible on this view in this
    file; confirm how it is registered with ``app``.
    """
    payload = request.get_json(silent=True)
    data = payload if isinstance(payload, dict) else {}
    # Log received data, but never the plaintext password.
    print("Received sign-up data:", {k: v for k, v in data.items() if k != 'password'})
    name = data.get('name')
    role = data.get('role')
    email = data.get('email')
    password = data.get('password')

    # Presence checks only; the email format is not validated here.
    if not email or not password:
        return jsonify({"message": "Email and password are required"}), 400
    # name and role are NOT NULL in Users; reject up front instead of letting
    # the INSERT fail with a server error.
    if name is None or role is None:
        return jsonify({"message": "Name and role are required"}), 400

    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        # Check if the email already exists
        cursor.execute('SELECT 1 FROM Users WHERE email = ?', (email,))
        if cursor.fetchone() is not None:
            return jsonify({"message": "Email already in use"}), 400

        # Hash the password before saving it
        hashed_password = generate_password_hash(password)
        try:
            cursor.execute(
                'INSERT INTO Users (name, role, email, password) VALUES (?, ?, ?, ?)',
                (name, role, email, hashed_password),
            )
            conn.commit()
        except pyodbc.IntegrityError:
            # Presumably the UNIQUE(email) constraint: a concurrent sign-up
            # inserted the same email between the SELECT and this INSERT.
            conn.rollback()
            return jsonify({"message": "Email already in use"}), 400

    print("User created successfully:", name, email)  # Log successful user creation
    return jsonify({"message": "User created successfully"}), 201
if __name__ == '__main__':
    # Hosting platforms (e.g. Hugging Face) supply PORT; locally fall back to 5000.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "5000")), debug=False)