# py-learn-backend / verification.py
# pykara's picture
# Update verification.py
# c6dc8f1 verified
# --- load .env FIRST ---
import os
from dotenv import load_dotenv
import requests
from werkzeug.utils import secure_filename
BASEDIR = os.path.abspath(os.path.dirname(__file__))
load_dotenv(os.path.join(BASEDIR, ".env")) # loads DB_USER, DB_PASSWORD, RUN_INIT_DB
import socket
import logging
from threading import Lock
from functools import wraps
import datetime
import bcrypt
import jwt
import pyodbc
from flask import Flask, request, jsonify, make_response, current_app
from flask_cors import CORS
# ------------------------------------------------------------------------------
# App, ENV, CORS
# ------------------------------------------------------------------------------
app = Flask(__name__)

# Signing key for the JWT cookies. Prefer the SECRET_KEY environment variable so
# the key can be rotated without a code change. The literal below is kept only
# as a backward-compatible fallback so existing deployments (and tokens they
# have already issued) keep working when the variable is not set.
_FALLBACK_SECRET_KEY = '96c63da06374c1bde332516f3acbd23c84f35f90d8a6321a25d790a0a451af32'
app.config['SECRET_KEY'] = os.getenv("SECRET_KEY") or _FALLBACK_SECRET_KEY

IS_PROD = os.getenv("ENV", "dev").lower() == "prod"
if IS_PROD and app.config['SECRET_KEY'] == _FALLBACK_SECRET_KEY:
    logging.getLogger(__name__).warning(
        "SECRET_KEY is not set; using the key embedded in the source code. "
        "Set SECRET_KEY in the environment for production deployments."
    )

# Comma-separated list of allowed origins; both localhost forms are allowed
# by default when ALLOWED_ORIGINS is not set.
_default_origins = "http://localhost:4200,http://127.0.0.1:4200"
_origins = os.getenv("ALLOWED_ORIGINS", _default_origins)
ALLOWED_ORIGINS = [o.strip() for o in _origins.split(",") if o.strip()]

CORS(
    app,
    resources={r"/*": {"origins": ALLOWED_ORIGINS}},
    supports_credentials=True,
    allow_headers=["Content-Type", "Authorization", "X-Requested-With", "X-User"],
    expose_headers=["Set-Cookie"],
    methods=["GET", "POST", "OPTIONS"],
)
def extract_username_from_request(req) -> str | None:
# 1) Header
hdr = req.headers.get("X-User")
if hdr:
return hdr
# 2) Body
data = req.get_json(silent=True) or {}
if data.get("username"):
return data.get("username")
# 3) JWT cookie from verification.py
token = req.cookies.get("access_token")
if token:
try:
payload = jwt.decode(token, current_app.config["SECRET_KEY"], algorithms=["HS256"])
return payload.get("username")
except jwt.ExpiredSignatureError:
return None
except jwt.InvalidTokenError:
return None
return None
@app.after_request
def add_cors_headers(resp):
    """Attach credentialed-CORS headers when the request's Origin is allowed.

    Args:
        resp: The outgoing Flask response.

    Returns:
        The same response object, with CORS headers set if the ``Origin``
        request header is one of ``ALLOWED_ORIGINS``; otherwise unchanged.
    """
    origin = request.headers.get("Origin")
    if origin and origin in ALLOWED_ORIGINS:
        # Echo the origin, never '*', when using credentials.
        resp.headers["Access-Control-Allow-Origin"] = origin
        # Merge "Origin" into any Vary value that is already present instead
        # of overwriting it (another handler may have set e.g. Accept-Encoding).
        resp.vary.add("Origin")
        resp.headers["Access-Control-Allow-Credentials"] = "true"
        resp.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-User"
        resp.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
    return resp
@app.before_request
def handle_options_early():
    """Answer OPTIONS (CORS preflight) requests before any route handler runs.

    Returns:
        ``None`` for non-OPTIONS requests, so normal handling continues.
        For OPTIONS, the default options response; when the ``Origin`` header
        is in ``ALLOWED_ORIGINS`` it also carries the CORS allow-* headers,
        mirroring the headers and method the browser asked for.
    """
    if request.method != "OPTIONS":
        return None

    preflight = app.make_default_options_response()
    caller_origin = request.headers.get("Origin")
    if not caller_origin or caller_origin not in ALLOWED_ORIGINS:
        return preflight

    incoming = request.headers
    preflight.headers["Access-Control-Allow-Origin"] = caller_origin
    preflight.headers["Access-Control-Allow-Credentials"] = "true"
    # Mirror what the browser requested, falling back to the usual set.
    preflight.headers["Access-Control-Allow-Headers"] = incoming.get(
        "Access-Control-Request-Headers",
        "Content-Type, Authorization, X-Requested-With, X-User",
    )
    preflight.headers["Access-Control-Allow-Methods"] = incoming.get(
        "Access-Control-Request-Method", "POST"
    )
    return preflight
# Configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
# API keys / shared config for blueprints (read from HF Secrets/ENV);
# an unset key is stored as an empty string.
app.config.update(COHERE_API_KEY=os.environ.get("COHERE_API_KEY", ""))
# ------------------------------------------------------------------------------
# SQL Server configuration
# ------------------------------------------------------------------------------
# DB_SERVER = "pykara-sqlserver.cb60o04yk948.ap-south-1.rds.amazonaws.com,1433"
# DB_DATABASE = "AuthenticationDB1"
DB_SERVER = os.getenv("DB_SERVER", r"(localdb)\MSSQLLocalDB")
DB_DATABASE = os.getenv("DB_DATABASE", "AuthenticationDB1")
DB_DRIVER = os.getenv("DB_DRIVER", "ODBC Driver 17 for SQL Server")


def odbc_quote_value(value):
    """Format *value* for use as an ODBC connection-string attribute value.

    Plain values are returned unchanged. A value containing ``;``, starting
    with ``{`` or carrying leading/trailing whitespace is wrapped in braces
    (with ``}`` doubled), so it cannot terminate or corrupt the attribute.
    ``None`` becomes an empty string rather than the text ``"None"``.
    """
    text = "" if value is None else str(value)
    needs_braces = ";" in text or text.startswith("{") or text != text.strip()
    if not needs_braces:
        return text
    return "{" + text.replace("}", "}}") + "}"


# A server counts as "local" (Windows integrated auth) when it is localhost,
# a dot-relative name, LocalDB, or contains a backslash (instance name).
is_local = (
    DB_SERVER.lower().startswith("localhost")
    or DB_SERVER.startswith(".")
    or DB_SERVER.lower().startswith("(localdb)")
    or "\\" in DB_SERVER
)

# Part shared by both authentication modes.
_conn_base = f"DRIVER={{{DB_DRIVER}}};SERVER={DB_SERVER};DATABASE={DB_DATABASE};"

if is_local:
    # Windows local / LocalDB using modern ODBC driver
    CONN_STR = _conn_base + "Trusted_Connection=yes;TrustServerCertificate=yes;"
else:
    # Remote SQL auth. Credentials are quoted so that special characters in
    # DB_USER / DB_PASSWORD do not break the connection string.
    # NOTE(review): TrustServerCertificate=yes skips server certificate
    # validation even though Encrypt=yes is set - confirm this is intended.
    CONN_STR = (
        _conn_base
        + f"UID={odbc_quote_value(os.getenv('DB_USER'))};"
        + f"PWD={odbc_quote_value(os.getenv('DB_PASSWORD'))};"
        + "Encrypt=yes;TrustServerCertificate=yes;"
    )
# def get_db_connection():
# """Create a short-timeout connection. Fail clearly if secrets are missing."""
# if "Trusted_Connection=yes" not in CONN_STR:
# if not os.getenv("DB_USER") or not os.getenv("DB_PASSWORD"):
# raise RuntimeError("DB_USER/DB_PASSWORD are not set in the environment.")
# return pyodbc.connect(CONN_STR, timeout=5)
def get_db_connection():
    """Open a pyodbc connection using ``CONN_STR`` with a 5-second timeout.

    Raises:
        RuntimeError: If the connection string does not use
            ``Trusted_Connection=yes`` and ``DB_USER`` or ``DB_PASSWORD``
            is missing or empty in the environment.
    """
    uses_sql_auth = "Trusted_Connection=yes" not in CONN_STR
    if uses_sql_auth and not (os.getenv("DB_USER") and os.getenv("DB_PASSWORD")):
        raise RuntimeError("DB_USER/DB_PASSWORD are not set in the environment.")
    return pyodbc.connect(CONN_STR, timeout=5)
@app.get("/db/diag")
def db_diag():
    """Report database connectivity diagnostics as JSON (always HTTP 200).

    The report lists the installed ODBC drivers, the configured server,
    database and driver names, a DNS lookup of the server host, and the
    outcome of a test connection. Each probe records its own error text
    instead of failing the request.
    """
    # NOTE(review): this route has no authentication and returns server names
    # and raw error messages - confirm it is meant to be publicly reachable.
    report = {}

    try:
        report["drivers_found"] = pyodbc.drivers()
    except Exception as exc:
        report["drivers_found_error"] = str(exc)

    report.update(
        db_server_env=DB_SERVER,
        db_database_env=DB_DATABASE,
        db_driver_env=DB_DRIVER,
    )

    # Keep only the host part of a "host,port" server value.
    hostname = DB_SERVER.partition(",")[0].strip()
    lookup = {"host": hostname}
    try:
        lookup["ip"] = socket.gethostbyname(hostname)
    except Exception as exc:
        lookup["error"] = str(exc)
    report["dns_lookup"] = lookup

    try:
        get_db_connection().close()
    except Exception as exc:
        report["connect"] = f"error: {exc}"
    else:
        report["connect"] = "ok"

    return jsonify(report), 200
def init_db():
    """Create the Users, BlacklistedTokens and RefreshTokens tables if missing.

    All statements run on one connection and are committed together. The
    connection is always closed, including when a statement fails, in which
    case the exception propagates to the caller.
    """
    ddl_statements = (
        """
        IF OBJECT_ID('Users', 'U') IS NULL
        CREATE TABLE Users (
            id INT IDENTITY(1,1) PRIMARY KEY,
            username NVARCHAR(100) UNIQUE NOT NULL,
            password_hash NVARCHAR(500) NOT NULL,
            role NVARCHAR(50) DEFAULT 'user'
        )
        """,
        """
        IF OBJECT_ID('BlacklistedTokens', 'U') IS NULL
        CREATE TABLE BlacklistedTokens (
            id INT IDENTITY(1,1) PRIMARY KEY,
            token NVARCHAR(1000) UNIQUE NOT NULL,
            created_at DATETIME DEFAULT GETDATE()
        )
        """,
        """
        IF OBJECT_ID('RefreshTokens', 'U') IS NULL
        CREATE TABLE RefreshTokens (
            id INT IDENTITY(1,1) PRIMARY KEY,
            username NVARCHAR(100) NOT NULL,
            token NVARCHAR(1000) UNIQUE NOT NULL,
            created_at DATETIME DEFAULT GETDATE(),
            FOREIGN KEY (username) REFERENCES Users(username) ON DELETE CASCADE
        )
        """,
    )

    conn = get_db_connection()
    try:
        cur = conn.cursor()
        # Users must be created first: RefreshTokens references it.
        for ddl in ddl_statements:
            cur.execute(ddl)
        conn.commit()
    finally:
        # Previously a failing statement left the connection open.
        conn.close()
# ------------------------------------------------------------------------------
# One-time DB initialisation (Flask 3.x safe)
# ------------------------------------------------------------------------------
# Initialisation state: whether it was requested via RUN_INIT_DB=1, whether
# an attempt has already been made, and the lock guarding that attempt.
_should_init = os.getenv("RUN_INIT_DB", "0") == "1"
_db_init_lock = Lock()
_db_init_done = False


@app.before_request
def maybe_init_db():
    """Run ``init_db()`` at most once, on the first request that gets here.

    Does nothing unless ``RUN_INIT_DB`` was ``"1"`` at import time. The done
    flag is checked before and again after taking the lock, so concurrent
    first requests make a single attempt. A failed attempt is logged and is
    not retried: the flag is set whether or not ``init_db()`` succeeds.
    """
    global _db_init_done
    if not _should_init or _db_init_done:
        return
    with _db_init_lock:
        if _db_init_done:
            # Another request finished the attempt while we waited.
            return
        try:
            init_db()
            app.logger.info("Database initialised.")
        except Exception as e:
            app.logger.exception("DB init failed: %s", e)
        finally:
            _db_init_done = True
# ------------------------------------------------------------------------------
# Cookie helpers
# ------------------------------------------------------------------------------
def add_cookie(resp, name: str, value: str, max_age: int):
    """Set an HttpOnly cookie on *resp* with environment-specific attributes.

    Dev (``IS_PROD`` false): ``SameSite=Lax``, not ``Secure``, set through
    ``resp.set_cookie``.
    Prod: ``Secure; HttpOnly; SameSite=None; Partitioned``, written as a raw
    ``Set-Cookie`` header.

    Args:
        resp: Response to modify in place.
        name: Cookie name.
        value: Cookie value (written as-is in prod).
        max_age: Lifetime in seconds.
    """
    if not IS_PROD:
        resp.set_cookie(
            name,
            value,
            max_age=max_age,
            path="/",
            secure=False,
            httponly=True,
            samesite="Lax",
        )
        return

    cookie_parts = (
        f"{name}={value}",
        "Path=/",
        f"Max-Age={max_age}",
        "Secure",
        "HttpOnly",
        "SameSite=None",
        "Partitioned",
    )
    resp.headers.add("Set-Cookie", "; ".join(cookie_parts))
# ------------------------------------------------------------------------------
# Health
# ------------------------------------------------------------------------------
@app.get("/")
def health():
    """Liveness probe: always answers ``{"status": "ok"}`` with HTTP 200."""
    body = {"status": "ok"}
    return body, 200
# ------------------------------------------------------------------------------
# Auth utilities
# ------------------------------------------------------------------------------
from functools import wraps
def token_required(f):
    """Decorator: require a valid, non-revoked ``access_token`` cookie.

    The wrapped view is called with the token's ``username`` claim as its
    first positional argument, followed by the route's own arguments.

    Responses produced by the decorator itself:
        401 - cookie missing, token found in BlacklistedTokens, token expired
              or token invalid.
        500 - any other error while checking the token (e.g. database
              failure or a token without a ``username`` claim); it is logged.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.cookies.get('access_token')
        if not token:
            return jsonify({"message": "Token is missing"}), 401
        try:
            # Check blacklist; always release the connection, even on error.
            conn = get_db_connection()
            try:
                cur = conn.cursor()
                cur.execute("SELECT token FROM BlacklistedTokens WHERE token = ?", (token,))
                revoked = cur.fetchone() is not None
            finally:
                conn.close()
            if revoked:
                return jsonify({"message": "Token has been revoked. Please log in again."}), 401
            data = jwt.decode(token, app.config['SECRET_KEY'], algorithms=["HS256"])
            username = data['username']
        except jwt.ExpiredSignatureError:
            return jsonify({"message": "Token has expired"}), 401
        except jwt.InvalidTokenError:
            return jsonify({"message": "Invalid token"}), 401
        except Exception as e:
            app.logger.exception("Auth error: %s", e)
            return jsonify({"message": "Server error"}), 500
        # The view runs outside the try block so that its own exceptions
        # (including abort()/HTTPException) reach Flask's normal error
        # handling instead of being reported as an authentication failure.
        return f(username, *args, **kwargs)
    return decorated
# ------------------------------------------------------------------------------
# Routes (verification/auth only)
# ------------------------------------------------------------------------------
@app.get("/dashboard")
@token_required
def dashboard(username):
    """Return a welcome message for the authenticated user.

    Args:
        username: Supplied by ``token_required`` from the access token.
    """
    greeting = f"Welcome {username} to your dashboard!"
    return jsonify(message=greeting)
@app.post("/login")
def login():
    """Authenticate a user and issue access/refresh token cookies.

    Expects a JSON object with ``username`` and ``password``.

    Responses:
        200 - credentials valid; sets ``access_token`` (15 min) and
              ``refresh_token`` (7 days) cookies and stores the refresh
              token in RefreshTokens.
        400 - ``username`` or ``password`` missing, empty or not a string.
        401 - unknown user or wrong password.
        503 - the database could not be read or written.
    """
    import uuid  # local import: only needed for the refresh token's unique id

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    # Without this check a missing password crashed on ``None.encode``.
    if not (isinstance(username, str) and username and isinstance(password, str) and password):
        return jsonify({"message": "Username and password are required"}), 400

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("SELECT password_hash FROM Users WHERE username = ?", (username,))
            row = cur.fetchone()
        finally:
            conn.close()
    except Exception as e:
        app.logger.exception("DB access error on login: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    if not row:
        return jsonify({"message": "Invalid credentials"}), 401

    stored_hash = row[0]
    try:
        password_ok = bcrypt.checkpw(password.encode('utf-8'), stored_hash.encode('utf-8'))
    except ValueError:
        # bcrypt rejects malformed hashes; treat as a failed login, not a crash.
        app.logger.error("Stored password hash for user %r is not a valid bcrypt hash", username)
        password_ok = False
    if not password_ok:
        return jsonify({"message": "Invalid credentials"}), 401

    now = datetime.datetime.now(datetime.timezone.utc)
    access_token = jwt.encode(
        {'username': username, 'exp': now + datetime.timedelta(minutes=15)},
        app.config['SECRET_KEY'],
        algorithm="HS256"
    )
    # 'exp' has one-second resolution, so two logins by the same user within
    # one second would otherwise yield identical tokens and violate the
    # UNIQUE constraint on RefreshTokens.token. A random 'jti' prevents that.
    refresh_token = jwt.encode(
        {
            'username': username,
            'exp': now + datetime.timedelta(days=7),
            'jti': uuid.uuid4().hex,
        },
        app.config['SECRET_KEY'],
        algorithm="HS256"
    )

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("INSERT INTO RefreshTokens (username, token) VALUES (?, ?)", (username, refresh_token))
            conn.commit()
        finally:
            conn.close()
    except Exception as e:
        app.logger.exception("DB write error on login: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    resp = make_response(jsonify({"message": "Login successful"}))
    add_cookie(resp, 'access_token', access_token, 900)  # 15 min
    add_cookie(resp, 'refresh_token', refresh_token, 7*24*60*60)  # 7 days
    return resp
@app.post("/refresh")
def refresh():
    """Issue a new 15-minute access token from the ``refresh_token`` cookie.

    Responses:
        200 - new token returned in the JSON body and set as the
              ``access_token`` cookie.
        400 - refresh cookie missing.
        401 - refresh token expired, invalid, or not present in
              RefreshTokens (e.g. removed by logout).
        503 - the database could not be read.
    """
    refresh_token = request.cookies.get("refresh_token")
    if not refresh_token:
        return jsonify({'message': 'Refresh token is missing'}), 400

    try:
        # Only signature and expiry are checked here; the username used
        # below comes from the stored RefreshTokens row.
        jwt.decode(refresh_token, app.config['SECRET_KEY'], algorithms=["HS256"])
    except jwt.ExpiredSignatureError:
        return jsonify({'message': 'Refresh token has expired'}), 401
    except jwt.InvalidTokenError:
        return jsonify({'message': 'Invalid refresh token'}), 401

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("SELECT username FROM RefreshTokens WHERE token = ?", (refresh_token,))
            row = cur.fetchone()
        finally:
            # Close even when the query fails, so the connection is not leaked.
            conn.close()
    except Exception as e:
        app.logger.exception("DB access error on refresh: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    if not row:
        return jsonify({'message': 'Invalid refresh token'}), 401

    username = row[0]
    expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=15)
    new_access = jwt.encode(
        {'username': username, 'exp': expires_at},
        app.config['SECRET_KEY'],
        algorithm="HS256"
    )
    # NOTE(review): the token is also returned in the body, which makes it
    # readable by page scripts despite the HttpOnly cookie - confirm needed.
    resp = make_response(jsonify({'access_token': new_access}))
    add_cookie(resp, 'access_token', new_access, 900)
    return resp
@app.post("/logout")
@token_required
def logout(username):
    """Revoke the current access token and all of the user's refresh tokens.

    Args:
        username: Supplied by ``token_required`` from the validated token.

    Responses:
        200 - access token added to BlacklistedTokens, the user's
              RefreshTokens rows deleted, both cookies expired.
        503 - the database could not be written.
        (401/500 may be returned earlier by ``token_required``.)
    """
    # token_required has already verified that the cookie exists and decodes,
    # so the token is not decoded a second time here.
    token = request.cookies.get('access_token')

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            cur.execute("SELECT token FROM BlacklistedTokens WHERE token = ?", (token,))
            if not cur.fetchone():
                cur.execute("INSERT INTO BlacklistedTokens (token) VALUES (?)", (token,))
            cur.execute("DELETE FROM RefreshTokens WHERE username = ?", (username,))
            conn.commit()
        finally:
            conn.close()
    except Exception as e:
        app.logger.exception("DB write error on logout: %s", e)
        return jsonify({"message": "Database is unavailable"}), 503

    resp = make_response(jsonify({"message": "Logged out successfully!"}))
    # Expire the cookies through add_cookie so the attributes match the ones
    # used when they were set. In prod the cookies are
    # Secure/SameSite=None/Partitioned; a plain delete_cookie() header omits
    # those attributes and does not remove such cookies.
    add_cookie(resp, 'access_token', '', 0)
    add_cookie(resp, 'refresh_token', '', 0)
    return resp
# @app.post("/upload-pdf")
# def upload_pdf():
# file = request.files.get("pdf")
# if not file:
# return jsonify({"error": "No file uploaded"}), 400
# upload_folder = os.path.join(BASEDIR, "pdfs")
# os.makedirs(upload_folder, exist_ok=True)
# save_path = os.path.join(upload_folder, file.filename)
# file.save(save_path)
# # You can optionally trigger RAG indexing here
# print(f"✅ PDF saved successfully at: {save_path}")
# return jsonify({"message": "PDF uploaded successfully", "path": save_path}), 200
@app.post("/upload-pdf")
def upload_pdf():
    """Save an uploaded PDF under ``BASEDIR/pdfs`` and trigger RAG ingestion.

    Expects a multipart form with the file in the ``pdf`` field.

    Responses:
        200 - file saved; the ``rag`` field holds the ingest service's JSON
              reply, or a ``warning`` entry if the ingest call failed (an
              ingest failure does not fail the upload).
        400 - no file, empty filename, or filename not ending in ``.pdf``.
    """
    # NOTE(review): this route has no authentication - confirm that anonymous
    # uploads are intended.
    import uuid  # local import: only needed for the fallback file name

    file = request.files.get("pdf")
    original_name = (file.filename or "").strip() if file else ""
    if not original_name:
        return jsonify({"error": "No file uploaded"}), 400
    if not original_name.lower().endswith(".pdf"):
        return jsonify({"error": "Only .pdf files are accepted"}), 400

    # Save to your backend's pdfs folder (BASEDIR/pdfs)
    upload_folder = os.path.join(BASEDIR, "pdfs")
    os.makedirs(upload_folder, exist_ok=True)

    filename = secure_filename(original_name)
    if not filename.lower().endswith(".pdf"):
        # secure_filename() drops non-ASCII characters, which can leave an
        # empty or extension-less name; use a generated one in that case.
        filename = f"upload-{uuid.uuid4().hex}.pdf"
    save_path = os.path.join(upload_folder, filename)
    file.save(save_path)
    app.logger.info("PDF saved successfully at: %s", save_path)

    # Trigger RAG ingestion for THIS file (auto-ingest)
    rag_ingest_url = os.getenv("RAG_INGEST_URL", "http://localhost:7000/rag/ingest")
    rag_result = {"status": "skipped"}
    try:
        payload = {
            "paths": [save_path],  # ingest this single PDF
            # optional tags (use if you plan to filter in RAG later)
            "subject": "English",
            "grade": "5"
        }
        ingest_resp = requests.post(rag_ingest_url, json=payload, timeout=30)
        ingest_resp.raise_for_status()
        rag_result = ingest_resp.json()
        app.logger.info("RAG ingest response: %s", rag_result)
    except Exception as e:
        # Do not fail the upload flow if ingest fails - just warn
        app.logger.warning("RAG ingest failed: %s", e)
        rag_result = {"status": "warning", "message": f"RAG ingest failed: {str(e)}"}

    # This response is for debugging/visibility
    return jsonify({
        "message": "PDF uploaded successfully",
        "path": save_path,
        "rag": rag_result
    }), 200
@app.get("/check-auth")
@token_required
def check_auth(username):
    """Confirm that the caller holds a valid access token.

    Args:
        username: Supplied by ``token_required`` from the access token.

    Returns:
        JSON with ``message`` and the ``username``, HTTP 200.
    """
    body = jsonify(message="Authenticated", username=username)
    return body, 200
# ------------------------------------------------------------------------------
# Register blueprints: feature modules (chat, questions, reading, RAG, pronunciation, ...)
# ------------------------------------------------------------------------------
from chat import movie_bp # ensure testmovie.py defines movie_bp = Blueprint(...)
from generateQuestion import questions_bp
from reading import reading_bp
from writting import writting_bp # match the exact file name on Linux
from vocabularyBuilder import vocab_bp
from findingword import finding_bp
from listen import listen_bp
from ragg.app import rag_bp
from pron import pron_bp
from pronvideo import pronvideo_bp
from pronragg import pronragg_bp
from pronragupgrade import pronragupgrade_bp
from ragg.ingest_trigger import ingest_trigger_bp
# (blueprint, URL prefix) pairs, registered in the order listed.
for _blueprint, _url_prefix in (
    (movie_bp, "/media"),
    (questions_bp, "/media"),
    (reading_bp, "/media"),
    (writting_bp, "/media"),
    (vocab_bp, "/media"),
    (finding_bp, "/media"),
    (listen_bp, "/media"),
    (rag_bp, "/rag"),
    (ingest_trigger_bp, "/rag"),
    (pron_bp, "/pron"),
    (pronvideo_bp, "/pronvideo"),
    (pronragg_bp, "/pronragg"),
    (pronragupgrade_bp, "/pronragupgrade"),
):
    app.register_blueprint(_blueprint, url_prefix=_url_prefix)
# app.register_blueprint(questions_bp, url_prefix="/media") # <-- add this
# ------------------------------------------------------------------------------
# Local run (Gunicorn will import `verification:app` on Spaces)
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    port = int(os.getenv("PORT", "5000"))
    # The server binds to all interfaces, and Flask's debug mode exposes an
    # interactive debugger that can execute code. Keep it on by default for
    # local development only: it is off when ENV=prod and can be overridden
    # either way with FLASK_DEBUG.
    _debug_setting = os.getenv("FLASK_DEBUG", "0" if IS_PROD else "1")
    debug = _debug_setting.strip().lower() in ("1", "true", "yes", "on")
    app.run(host="0.0.0.0", port=port, debug=debug)