Spaces:
Running
Running
| # --- load .env FIRST --- | |
| import os | |
| from dotenv import load_dotenv | |
| import requests | |
| from werkzeug.utils import secure_filename | |
# Resolve BASEDIR from this file's own location (not the working directory)
# and load .env right away, so later os.getenv() calls see its values
# (e.g. DB_USER, DB_PASSWORD, RUN_INIT_DB).
BASEDIR = os.path.abspath(os.path.dirname(__file__))
load_dotenv(dotenv_path=os.path.join(BASEDIR, ".env"))
| import socket | |
| import logging | |
| from threading import Lock | |
| from functools import wraps | |
| import datetime | |
| import bcrypt | |
| import jwt | |
| import pyodbc | |
| from flask import Flask, request, jsonify, make_response, current_app | |
| from flask_cors import CORS | |
| # ------------------------------------------------------------------------------ | |
| # App, ENV, CORS | |
| # ------------------------------------------------------------------------------ | |
app = Flask(__name__)

# NOTE(security): this literal signing key has been committed to source control
# and must be treated as compromised. Set SECRET_KEY in the environment
# (HF Secrets) to override it. The literal is kept only as a fallback so
# existing deployments and already-issued tokens keep working until rotated.
_LEGACY_SECRET_KEY = "96c63da06374c1bde332516f3acbd23c84f35f90d8a6321a25d790a0a451af32"
app.config["SECRET_KEY"] = os.getenv("SECRET_KEY") or _LEGACY_SECRET_KEY

IS_PROD = os.getenv("ENV", "dev").lower() == "prod"

# Comma-separated origins allowed to make credentialed cross-origin requests.
# When ALLOWED_ORIGINS is unset, both localhost forms of the dev frontend are
# allowed.
_default_origins = "http://localhost:4200,http://127.0.0.1:4200"
_origins = os.getenv("ALLOWED_ORIGINS", _default_origins)
ALLOWED_ORIGINS = [o.strip() for o in _origins.split(",") if o.strip()]

CORS(
    app,
    resources={r"/*": {"origins": ALLOWED_ORIGINS}},
    supports_credentials=True,
    allow_headers=["Content-Type", "Authorization", "X-Requested-With", "X-User"],
    expose_headers=["Set-Cookie"],
    methods=["GET", "POST", "OPTIONS"],
)
| def extract_username_from_request(req) -> str | None: | |
| # 1) Header | |
| hdr = req.headers.get("X-User") | |
| if hdr: | |
| return hdr | |
| # 2) Body | |
| data = req.get_json(silent=True) or {} | |
| if data.get("username"): | |
| return data.get("username") | |
| # 3) JWT cookie from verification.py | |
| token = req.cookies.get("access_token") | |
| if token: | |
| try: | |
| payload = jwt.decode(token, current_app.config["SECRET_KEY"], algorithms=["HS256"]) | |
| return payload.get("username") | |
| except jwt.ExpiredSignatureError: | |
| return None | |
| except jwt.InvalidTokenError: | |
| return None | |
| return None | |
def add_cors_headers(resp):
    """Add credentialed CORS headers to *resp* and return it.

    The request's ``Origin`` is echoed back only if it is listed in
    ``ALLOWED_ORIGINS``; ``*`` cannot be used together with credentials.
    ``Origin`` is merged into ``Vary`` on every response, because the CORS
    headers depend on the request origin. Existing ``Vary`` entries (such as
    ``Accept-Encoding``) are kept rather than overwritten.

    NOTE(review): no ``@app.after_request`` registration is visible here;
    confirm how this hook is wired up.
    """
    vary_values = [v.strip() for v in resp.headers.get("Vary", "").split(",") if v.strip()]
    lowered = {v.lower() for v in vary_values}
    # "Vary: *" already covers every request header.
    if not lowered & {"origin", "*"}:
        vary_values.append("Origin")
    resp.headers["Vary"] = ", ".join(vary_values)

    origin = request.headers.get("Origin")
    if origin and origin in ALLOWED_ORIGINS:
        # Echo the specific origin, never '*', when using credentials.
        resp.headers["Access-Control-Allow-Origin"] = origin
        resp.headers["Access-Control-Allow-Credentials"] = "true"
        resp.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-User"
        resp.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
    return resp
def handle_options_early():
    """Short-circuit CORS preflight (``OPTIONS``) requests.

    For ``OPTIONS`` it returns Flask's default options response. When the
    ``Origin`` is whitelisted, it adds credentialed CORS headers and mirrors
    the browser's requested headers and method back. Any other method yields
    ``None``, so normal dispatch continues.

    NOTE(review): no ``@app.before_request`` registration is visible here;
    confirm how this hook is wired up.
    """
    if request.method != "OPTIONS":
        return None

    preflight = app.make_default_options_response()
    origin = request.headers.get("Origin")
    if not origin or origin not in ALLOWED_ORIGINS:
        return preflight

    hdrs = preflight.headers
    hdrs["Access-Control-Allow-Origin"] = origin
    hdrs["Access-Control-Allow-Credentials"] = "true"
    # Mirror what the browser asked for, with conservative defaults.
    hdrs["Access-Control-Allow-Headers"] = request.headers.get(
        "Access-Control-Request-Headers",
        "Content-Type, Authorization, X-Requested-With, X-User",
    )
    hdrs["Access-Control-Allow-Methods"] = request.headers.get(
        "Access-Control-Request-Method", "POST"
    )
    return preflight
logging.basicConfig(level=logging.INFO)

# Shared settings for blueprints, read from HF Secrets / the environment.
# An unset key is stored as "" rather than None.
app.config.update(COHERE_API_KEY=os.getenv("COHERE_API_KEY", ""))
| # ------------------------------------------------------------------------------ | |
| # SQL Server configuration | |
| # ------------------------------------------------------------------------------ | |
| # DB_SERVER = "pykara-sqlserver.cb60o04yk948.ap-south-1.rds.amazonaws.com,1433" | |
| # DB_DATABASE = "AuthenticationDB1" | |
| DB_SERVER = os.getenv("DB_SERVER", r"(localdb)\MSSQLLocalDB") | |
| DB_DATABASE = os.getenv("DB_DATABASE", "AuthenticationDB1") | |
| DB_DRIVER = os.getenv("DB_DRIVER", "ODBC Driver 17 for SQL Server") # 17 in your image | |
| # Build connection string (FIXED) | |
| is_local = ( | |
| DB_SERVER.lower().startswith("localhost") | |
| or DB_SERVER.startswith(".") | |
| or DB_SERVER.lower().startswith("(localdb)") | |
| or "\\" in DB_SERVER | |
| ) | |
| if is_local: | |
| # Windows local / LocalDB using modern ODBC driver | |
| CONN_STR = ( | |
| f"DRIVER={{{DB_DRIVER}}};" | |
| f"SERVER={DB_SERVER};" | |
| f"DATABASE={DB_DATABASE};" | |
| "Trusted_Connection=yes;" | |
| "TrustServerCertificate=yes;" | |
| ) | |
| else: | |
| # Remote SQL auth | |
| CONN_STR = ( | |
| f"DRIVER={{{DB_DRIVER}}};" | |
| f"SERVER={DB_SERVER};DATABASE={DB_DATABASE};" | |
| f"UID={os.getenv('DB_USER')};PWD={os.getenv('DB_PASSWORD')};" | |
| "Encrypt=yes;TrustServerCertificate=yes;" | |
| ) | |
def get_db_connection():
    """Open a new pyodbc connection to the configured SQL Server.

    ``timeout=5`` is passed to ``pyodbc.connect`` so connection failures
    surface quickly. When SQL authentication is in use (``CONN_STR`` does not
    contain ``Trusted_Connection=yes``), both ``DB_USER`` and ``DB_PASSWORD``
    must be set in the environment.

    Raises:
        RuntimeError: SQL auth is configured but credentials are missing.
        Any error raised by ``pyodbc.connect`` propagates unchanged.
    """
    uses_sql_auth = "Trusted_Connection=yes" not in CONN_STR
    creds_present = bool(os.getenv("DB_USER")) and bool(os.getenv("DB_PASSWORD"))
    if uses_sql_auth and not creds_present:
        raise RuntimeError("DB_USER/DB_PASSWORD are not set in the environment.")
    return pyodbc.connect(CONN_STR, timeout=5)
def db_diag():
    """Return a JSON snapshot of database connectivity diagnostics.

    The snapshot collects:
    - the installed ODBC drivers;
    - the effective DB_* settings;
    - a DNS lookup of the server host (the part before any ``,port``);
    - the outcome of a real connection attempt.

    Each probe records its own error instead of raising, so the HTTP status
    is always 200.

    NOTE(review): this exposes server names and raw driver error text. Confirm
    it is not reachable by untrusted clients in production.
    """
    report = {}
    try:
        report["drivers_found"] = pyodbc.drivers()
    except Exception as exc:
        report["drivers_found_error"] = str(exc)

    host = DB_SERVER.split(",")[0].strip()
    report.update(
        db_server_env=DB_SERVER,
        db_database_env=DB_DATABASE,
        db_driver_env=DB_DRIVER,
    )

    try:
        report["dns_lookup"] = {"host": host, "ip": socket.gethostbyname(host)}
    except Exception as exc:
        report["dns_lookup"] = {"host": host, "error": str(exc)}

    try:
        get_db_connection().close()
    except Exception as exc:
        report["connect"] = f"error: {exc}"
    else:
        report["connect"] = "ok"

    return jsonify(report), 200
def init_db():
    """Create the Users, BlacklistedTokens and RefreshTokens tables if absent.

    Each CREATE is guarded by ``IF OBJECT_ID(..., 'U') IS NULL``, so repeated
    runs are harmless. RefreshTokens.username references Users(username) with
    ON DELETE CASCADE. The statements are followed by a single ``commit()``.
    The connection is always closed, even when a statement fails; that
    exception still propagates to the caller.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor()
        cur.execute("""
            IF OBJECT_ID('Users', 'U') IS NULL
            CREATE TABLE Users (
                id INT IDENTITY(1,1) PRIMARY KEY,
                username NVARCHAR(100) UNIQUE NOT NULL,
                password_hash NVARCHAR(500) NOT NULL,
                role NVARCHAR(50) DEFAULT 'user'
            )
        """)
        cur.execute("""
            IF OBJECT_ID('BlacklistedTokens', 'U') IS NULL
            CREATE TABLE BlacklistedTokens (
                id INT IDENTITY(1,1) PRIMARY KEY,
                token NVARCHAR(1000) UNIQUE NOT NULL,
                created_at DATETIME DEFAULT GETDATE()
            )
        """)
        cur.execute("""
            IF OBJECT_ID('RefreshTokens', 'U') IS NULL
            CREATE TABLE RefreshTokens (
                id INT IDENTITY(1,1) PRIMARY KEY,
                username NVARCHAR(100) NOT NULL,
                token NVARCHAR(1000) UNIQUE NOT NULL,
                created_at DATETIME DEFAULT GETDATE(),
                FOREIGN KEY (username) REFERENCES Users(username) ON DELETE CASCADE
            )
        """)
        conn.commit()
    finally:
        conn.close()
| # ------------------------------------------------------------------------------ | |
| # One-time DB initialisation (Flask 3.x safe) | |
| # ------------------------------------------------------------------------------ | |
_db_init_done = False
_db_init_lock = Lock()
_should_init = os.getenv("RUN_INIT_DB", "0") == "1"


def maybe_init_db():
    """Run init_db() at most once per process, and only when RUN_INIT_DB=1.

    The done-flag is checked, then re-checked under ``_db_init_lock``, so
    concurrent callers do not both run init. A failed attempt is logged and
    still marks init as done, so it is not retried in this process.
    """
    global _db_init_done
    if not _should_init or _db_init_done:
        return
    with _db_init_lock:
        if _db_init_done:
            return
        try:
            init_db()
            app.logger.info("Database initialised.")
        except Exception as e:
            app.logger.exception("DB init failed: %s", e)
        finally:
            _db_init_done = True
| # ------------------------------------------------------------------------------ | |
| # Cookie helpers | |
| # ------------------------------------------------------------------------------ | |
def add_cookie(resp, name: str, value: str, max_age: int):
    """Attach an HttpOnly, site-wide (``Path=/``) cookie to *resp*.

    Prod: the Set-Cookie header is written directly with
    ``Secure; HttpOnly; SameSite=None; Partitioned`` so the cookie can be used
    cross-site under third-party cookie protections.
    Dev: ``SameSite=Lax`` and not ``Secure``, so it works over plain HTTP.
    """
    if not IS_PROD:
        resp.set_cookie(
            name,
            value,
            httponly=True,
            secure=False,
            samesite="Lax",
            max_age=max_age,
            path="/",
        )
        return

    attributes = [
        f"{name}={value}",
        "Path=/",
        f"Max-Age={max_age}",
        "Secure",
        "HttpOnly",
        "SameSite=None",
        "Partitioned",
    ]
    resp.headers.add("Set-Cookie", "; ".join(attributes))
| # ------------------------------------------------------------------------------ | |
| # Health | |
| # ------------------------------------------------------------------------------ | |
def health():
    """Liveness probe: always reports ``{"status": "ok"}`` with HTTP 200."""
    body = {"status": "ok"}
    return body, 200
| # ------------------------------------------------------------------------------ | |
| # Auth utilities | |
| # ------------------------------------------------------------------------------ | |
from functools import wraps


def token_required(f):
    """Decorator requiring a valid, non-revoked ``access_token`` cookie.

    On success the wrapped view is called as ``f(username, *args, **kwargs)``,
    where ``username`` is the token's ``username`` claim. Otherwise a JSON
    response is returned:
    - 401 when the token is missing, revoked, expired, invalid, or has no
      ``username`` claim;
    - 500 when the blacklist lookup fails.

    ``functools.wraps`` keeps the view's ``__name__``, so Flask derives a
    distinct endpoint for each protected view. Without it every wrapper is
    named ``decorated`` and the endpoint names collide.

    The view itself runs outside the auth ``try`` block, so its exceptions
    (including ``abort()``) reach Flask's normal error handling instead of
    being turned into a generic 500.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.cookies.get('access_token')
        if not token:
            return jsonify({"message": "Token is missing"}), 401
        try:
            # Check blacklist (tokens revoked by logout).
            conn = get_db_connection()
            try:
                cur = conn.cursor()
                cur.execute("SELECT token FROM BlacklistedTokens WHERE token = ?", (token,))
                revoked = cur.fetchone() is not None
            finally:
                conn.close()
            if revoked:
                return jsonify({"message": "Token has been revoked. Please log in again."}), 401
            data = jwt.decode(token, app.config['SECRET_KEY'], algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            return jsonify({"message": "Token has expired"}), 401
        except jwt.InvalidTokenError:
            return jsonify({"message": "Invalid token"}), 401
        except Exception as e:
            app.logger.exception("Auth error: %s", e)
            return jsonify({"message": "Server error"}), 500

        username = data.get('username')
        if not username:
            return jsonify({"message": "Invalid token"}), 401
        return f(username, *args, **kwargs)
    return decorated
| # ------------------------------------------------------------------------------ | |
| # Routes (verification/auth only) | |
| # ------------------------------------------------------------------------------ | |
def dashboard(username):
    """Return a JSON greeting for *username*.

    NOTE(review): *username* is presumably supplied by ``token_required``;
    the route/decorator registration is not visible here.
    """
    greeting = f"Welcome {username} to your dashboard!"
    return jsonify({"message": greeting})
def login():
    """Authenticate a user and issue access/refresh JWT cookies.

    Expects a JSON object body ``{"username": str, "password": str}``. The
    password is checked with bcrypt against ``Users.password_hash``. On
    success, the refresh token is stored in ``RefreshTokens`` and two cookies
    are set via add_cookie:
    - ``access_token``: 15 minutes;
    - ``refresh_token``: 7 days.
    Both tokens are HS256-signed with ``SECRET_KEY``.

    Returns:
        200 ``{"message": "Login successful"}`` on success;
        400 when username or password is missing or not a string;
        401 for an unknown user or wrong password;
        503 when the database cannot be reached.
    """
    import secrets  # only used to give each token a unique `jti`

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    # Reject missing / non-string credentials up front. Previously a missing
    # password crashed with AttributeError on .encode().
    if not (isinstance(username, str) and username and isinstance(password, str) and password):
        return jsonify({"message": "Username and password are required"}), 400

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("SELECT password_hash FROM Users WHERE username = ?", (username,))
            row = cur.fetchone()
        finally:
            conn.close()
    except Exception as e:
        app.logger.exception("DB access error on login: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    if not row:
        return jsonify({"message": "Invalid credentials"}), 401
    stored_hash = row[0]
    if not bcrypt.checkpw(password.encode('utf-8'), stored_hash.encode('utf-8')):
        return jsonify({"message": "Invalid credentials"}), 401

    now = datetime.datetime.now(datetime.timezone.utc)
    secret = app.config['SECRET_KEY']
    # `jti` makes every token unique. Without it, two logins by the same user
    # within the same second produced byte-identical refresh tokens, and the
    # second INSERT violated the UNIQUE constraint on RefreshTokens.token.
    access_token = jwt.encode(
        {
            'username': username,
            'exp': now + datetime.timedelta(minutes=15),
            'jti': secrets.token_hex(16),
        },
        secret,
        algorithm="HS256",
    )
    refresh_token = jwt.encode(
        {
            'username': username,
            'exp': now + datetime.timedelta(days=7),
            'jti': secrets.token_hex(16),
        },
        secret,
        algorithm="HS256",
    )

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("INSERT INTO RefreshTokens (username, token) VALUES (?, ?)", (username, refresh_token))
            conn.commit()
        finally:
            conn.close()
    except Exception as e:
        app.logger.exception("DB write error on login: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    resp = make_response(jsonify({"message": "Login successful"}))
    add_cookie(resp, 'access_token', access_token, 900)  # 15 min
    add_cookie(resp, 'refresh_token', refresh_token, 7 * 24 * 60 * 60)  # 7 days
    return resp
def refresh():
    """Issue a new 15-minute access token from the ``refresh_token`` cookie.

    The refresh JWT must pass signature and expiry verification and must
    still be present in ``RefreshTokens`` (logout deletes a user's rows). The
    username for the new token comes from the stored row.

    Returns:
        200 ``{"access_token": ...}`` and re-sets the ``access_token`` cookie;
        400 if the cookie is missing;
        401 if the token is expired, invalid or unknown;
        503 if the database is unreachable.
    """
    refresh_token = request.cookies.get("refresh_token")
    if not refresh_token:
        return jsonify({'message': 'Refresh token is missing'}), 400

    try:
        # Called only for verification; the username is read from the DB below.
        jwt.decode(refresh_token, app.config['SECRET_KEY'], algorithms=["HS256"])
    except jwt.ExpiredSignatureError:
        return jsonify({'message': 'Refresh token has expired'}), 401
    except jwt.InvalidTokenError:
        return jsonify({'message': 'Invalid refresh token'}), 401

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("SELECT username FROM RefreshTokens WHERE token = ?", (refresh_token,))
            row = cur.fetchone()
        finally:
            conn.close()
    except Exception as e:
        app.logger.exception("DB access error on refresh: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    if not row:
        return jsonify({'message': 'Invalid refresh token'}), 401
    username = row[0]

    new_access = jwt.encode(
        {
            'username': username,
            'exp': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=15),
        },
        app.config['SECRET_KEY'],
        algorithm="HS256",
    )
    resp = make_response(jsonify({'access_token': new_access}))
    add_cookie(resp, 'access_token', new_access, 900)
    return resp
def logout(username):
    """Revoke the current session and clear the auth cookies.

    Steps:
    1. Re-decode the ``access_token`` cookie. The *username* argument is
       replaced by the token's ``username`` claim.
    2. Blacklist that access token if it is not already listed.
    3. Delete every stored refresh token for the user.
    4. Expire both cookies.

    Cookies are expired through add_cookie with ``max_age=0``, so the deletion
    header carries the same attributes the cookies were set with. In prod they
    are cross-site ``Secure; SameSite=None; Partitioned`` cookies. A deletion
    header without those attributes (as ``delete_cookie`` sent) may be
    rejected by the browser or target a different cookie, leaving the session
    cookies in place.

    Returns:
        200 on success;
        401 if the token is missing, expired or invalid;
        503 if the database is unreachable.
    """
    token = request.cookies.get('access_token')
    if not token:
        return jsonify({"message": "Invalid token format"}), 401
    try:
        data = jwt.decode(token, app.config['SECRET_KEY'], algorithms=["HS256"])
        username = data['username']
    except jwt.ExpiredSignatureError:
        return jsonify({"message": "Token has expired"}), 401
    except jwt.InvalidTokenError:
        return jsonify({"message": "Invalid token"}), 401

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("SELECT token FROM BlacklistedTokens WHERE token = ?", (token,))
            if not cur.fetchone():
                cur.execute("INSERT INTO BlacklistedTokens (token) VALUES (?)", (token,))
            cur.execute("DELETE FROM RefreshTokens WHERE username = ?", (username,))
            conn.commit()
        finally:
            conn.close()
    except Exception as e:
        app.logger.exception("DB write error on logout: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    resp = make_response(jsonify({"message": "Logged out successfully!"}))
    add_cookie(resp, 'access_token', '', 0)
    add_cookie(resp, 'refresh_token', '', 0)
    return resp
def upload_pdf():
    """Save an uploaded PDF under ``BASEDIR/pdfs`` and trigger RAG ingestion.

    Expects a multipart form field ``pdf``. The client-supplied filename is
    sanitised with ``secure_filename``; an existing file with the same name is
    overwritten. Ingestion is best-effort: the file path is POSTed to
    ``RAG_INGEST_URL`` (default ``http://localhost:7000/rag/ingest``), and any
    failure is reported in the ``rag`` field without failing the upload.

    Returns:
        200 ``{"message", "path", "rag"}`` on success;
        400 if no file (or no usable filename) was sent.

    NOTE(review): no auth decorator is visible on this view, and neither the
    extension nor the content is checked to be PDF. Confirm both are intended.
    """
    file = request.files.get("pdf")
    if not file or not (file.filename or "").strip():
        return jsonify({"error": "No file uploaded"}), 400

    # secure_filename() can reduce names such as "../.." to "". Saving under an
    # empty name would target the upload folder itself.
    filename = secure_filename(file.filename)
    if not filename:
        return jsonify({"error": "Invalid file name"}), 400

    upload_folder = os.path.join(BASEDIR, "pdfs")
    os.makedirs(upload_folder, exist_ok=True)
    save_path = os.path.join(upload_folder, filename)
    file.save(save_path)
    app.logger.info("PDF saved successfully at: %s", save_path)

    # Trigger RAG ingestion for this single file. The subject/grade tags are
    # optional metadata for later filtering.
    rag_ingest_url = os.getenv("RAG_INGEST_URL", "http://localhost:7000/rag/ingest")
    payload = {
        "paths": [save_path],
        "subject": "English",
        "grade": "5",
    }
    try:
        rag_resp = requests.post(rag_ingest_url, json=payload, timeout=30)
        rag_resp.raise_for_status()
        rag_result = rag_resp.json()
        app.logger.info("RAG ingest response: %s", rag_result)
    except Exception as e:
        # Deliberately best-effort: do not fail the upload if ingestion fails.
        app.logger.warning("RAG ingest failed: %s", e)
        rag_result = {"status": "warning", "message": f"RAG ingest failed: {str(e)}"}

    return jsonify({
        "message": "PDF uploaded successfully",
        "path": save_path,
        "rag": rag_result,
    }), 200
def check_auth(username):
    """Report that the caller is authenticated, echoing *username* (HTTP 200)."""
    payload = {"message": "Authenticated", "username": username}
    return jsonify(payload), 200
# ------------------------------------------------------------------------------
# Feature blueprints: imported from sibling modules, mounted under URL prefixes
# ------------------------------------------------------------------------------
from chat import movie_bp  # chat.py must define movie_bp = Blueprint(...)
from generateQuestion import questions_bp
from reading import reading_bp
from writting import writting_bp  # module name must match the file name exactly (Linux)
from vocabularyBuilder import vocab_bp
from findingword import finding_bp
from listen import listen_bp
from ragg.app import rag_bp
from pron import pron_bp
from pronvideo import pronvideo_bp
from pronragg import pronragg_bp
from pronragupgrade import pronragupgrade_bp
from ragg.ingest_trigger import ingest_trigger_bp

# (blueprint, url_prefix) pairs, registered in exactly this order. Several
# blueprints share the "/media" and "/rag" prefixes.
_BLUEPRINT_MOUNTS = (
    (movie_bp, "/media"),
    (questions_bp, "/media"),
    (reading_bp, "/media"),
    (writting_bp, "/media"),
    (vocab_bp, "/media"),
    (finding_bp, "/media"),
    (listen_bp, "/media"),
    (rag_bp, "/rag"),
    (ingest_trigger_bp, "/rag"),
    (pron_bp, "/pron"),
    (pronvideo_bp, "/pronvideo"),
    (pronragg_bp, "/pronragg"),
    (pronragupgrade_bp, "/pronragupgrade"),
)
for _blueprint, _prefix in _BLUEPRINT_MOUNTS:
    app.register_blueprint(_blueprint, url_prefix=_prefix)
| # ------------------------------------------------------------------------------ | |
| # Local run (Gunicorn will import `verification:app` on Spaces) | |
| # ------------------------------------------------------------------------------ | |
if __name__ == '__main__':
    # Local development entry point only; on Spaces, Gunicorn imports
    # `verification:app` instead. The Werkzeug debugger allows arbitrary code
    # execution from the browser, so never enable it in prod while binding
    # 0.0.0.0.
    port = int(os.getenv("PORT", "5000"))
    app.run(host="0.0.0.0", port=port, debug=not IS_PROD)