import base64
import json
import os
import threading
import time
import uuid

# Set before `transformers` is imported (same ordering as the original file)
# so the library sees the cache location.
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'

import requests
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_sock import Sock
from transformers import pipeline


# Flask application with flask-cors applied to it and a flask-sock
# WebSocket extension attached.
app = Flask(__name__)
CORS(app)
sock = Sock(app)

# Zero-shot classification pipeline, built once when the module is loaded.
classifier = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-1")

# In-memory session store keyed by session id. Entries are added by
# create_chat() and removed by cleanup_session().
SESSIONS = {}

# Candidate labels handed to the classifier when screening messages.
SENSITIVE_LABELS = ["terrorism", "blackmail", "national security threat"]

# Endpoint that log_flagged_session() posts flagged-session reports to.
STORAGE_API = "https://mike23415-storage.hf.space/api/flag"


def decrypt_message(encrypted_b64, password):
    """Decrypt a base64-encoded AES-GCM message with a password-derived key.

    The AES key is the SHA-256 digest of ``password``. The decoded payload is
    read as ``nonce (12 bytes) || ciphertext || tag (16 bytes)`` and the tag
    is verified, so tampered or wrongly-keyed messages are rejected instead
    of being decoded into garbage.

    NOTE(review): the trailing 16-byte tag is the layout produced by WebCrypto
    and most AES-GCM libraries -- confirm it matches the client.

    Args:
        encrypted_b64: Base64 text (or bytes) of the encrypted payload.
        password: Session password the key is derived from.

    Returns:
        The decrypted UTF-8 text, or ``None`` if the payload is malformed,
        fails authentication, or is not valid UTF-8.
    """
    nonce_len, tag_len = 12, 16
    try:
        encrypted = base64.b64decode(encrypted_b64)
        if len(encrypted) < nonce_len + tag_len:
            print("Decryption failed: payload too short")
            return None

        nonce = encrypted[:nonce_len]
        ciphertext = encrypted[nonce_len:-tag_len]
        tag = encrypted[-tag_len:]

        key = SHA256.new(password.encode()).digest()
        cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
        # decrypt_and_verify raises ValueError when the tag does not match.
        return cipher.decrypt_and_verify(ciphertext, tag).decode()
    except (ValueError, TypeError, AttributeError) as e:
        # ValueError: bad base64, failed tag check, or invalid UTF-8.
        # TypeError / AttributeError: non-text payload or password.
        print(f"Decryption failed: {e}")
        return None


def flag_if_sensitive(decrypted_text, ip, session_id, role, encrypted_msg, threshold=0.8):
    """Classify a decrypted message and record it on the session if flagged.

    The text is scored against ``SENSITIVE_LABELS``. The first label whose
    score exceeds ``threshold`` marks the session as flagged and appends one
    entry to its ``flagged_messages`` list; at most one entry is added per
    message.

    Args:
        decrypted_text: Plaintext of the message; falsy values are ignored.
        ip: Address of the client that sent the message.
        session_id: Key of the session in ``SESSIONS``.
        role: Role label of the sender within the session.
        encrypted_msg: The message as received, stored next to the plaintext.
        threshold: Score a label must exceed for the message to be flagged.

    Returns:
        None.
    """
    if not decrypted_text:
        return

    session = SESSIONS.get(session_id)
    if session is None:
        # The session may already have been removed by cleanup_session().
        return

    # multi_label=True scores every label independently. Without it the
    # scores are normalised to sum to 1 across the candidate labels, which
    # inflates the top label even for text unrelated to all of them.
    result = classifier(decrypted_text, SENSITIVE_LABELS, multi_label=True)
    for label, score in zip(result["labels"], result["scores"]):
        if score > threshold:
            print(f"⚠️ FLAGGED: {label} with score {score}")
            session["flagged"] = True
            session["flagged_messages"].append({
                "encrypted_msg": encrypted_msg,
                "decrypted_msg": decrypted_text,
                "label": label,
                "score": score,
                "role": role,
                "ip": ip,
                "timestamp": time.time(),
            })
            break


def log_flagged_session(session_id):
    """Send a report for a flagged session to ``STORAGE_API``.

    Does nothing when the session is unknown or was never flagged. The report
    is best-effort: failures are printed and never raised to the caller.

    Args:
        session_id: Key of the session in ``SESSIONS``.

    Returns:
        None.
    """
    session = SESSIONS.get(session_id)
    if session is None or not session["flagged"]:
        return

    payload = {
        "session_id": session_id,
        "created_at": session["created_at"],
        "messages": session["messages"],
        "unique_ips": sorted({msg["ip"] for msg in session["messages"]}),
        "flagged_messages": session["flagged_messages"],
    }
    try:
        response = requests.post(STORAGE_API, json=payload, timeout=3)
        # A 4xx/5xx reply means the report was not stored; treat it as a
        # failure instead of announcing success.
        response.raise_for_status()
    except (requests.RequestException, TypeError, ValueError) as e:
        # TypeError / ValueError cover payloads that cannot be JSON-encoded.
        print(f"Failed to log session {session_id}: {e}")
    else:
        print(f"Logged flagged session {session_id}")


def cleanup_session(session_id):
    """Report the session if it was flagged, then drop it from ``SESSIONS``.

    Unknown session ids are ignored.
    """
    if session_id not in SESSIONS:
        return

    # Report first: log_flagged_session() reads the entry from SESSIONS.
    log_flagged_session(session_id)
    del SESSIONS[session_id]
    print(f"Deleted session {session_id}")


@app.route("/api/create_chat", methods=["POST"])
def create_chat():
    """Create a chat session and schedule its removal after 900 seconds.

    Reads an optional ``password`` field from the JSON request body
    (``"default"`` when absent), stores a new session in ``SESSIONS`` and
    starts a timer that calls ``cleanup_session`` for it.

    Returns:
        A JSON response with ``session_id``, ``short_id`` (the first eight
        characters of the session id) and ``short_url``; or a 400 JSON error
        when the body is not a JSON object or the password is not a string.
    """
    data = request.get_json()
    if data is None:
        # A JSON `null` body is treated like a body with no fields.
        data = {}
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    password = data.get("password", "default")
    if not isinstance(password, str):
        # decrypt_message() encodes the password as text; any other type
        # would make every decryption in this session fail.
        return jsonify({"error": "password must be a string"}), 400

    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "password": password,
        "created_at": time.time(),
        "messages": [],
        "flagged": False,
        "flagged_messages": [],
        "connections": [],
    }

    # Expire the session (reporting it first if flagged) after 15 minutes.
    threading.Timer(900, cleanup_session, args=[session_id]).start()

    short_id = session_id[:8]
    short_url = f"https://{request.host}/s/{short_id}"
    return jsonify({"session_id": session_id, "short_id": short_id, "short_url": short_url})


@sock.route('/ws/<session_id>')
def chat(ws, session_id):
    """Relay encrypted chat messages between the clients of one session.

    The first client to join a session is labelled ``"Sender"``; later ones
    are ``"Receiver 1"``, ``"Receiver 2"``, ... in join order. Every received
    message is stored on the session, passed through ``decrypt_message`` and
    ``flag_if_sensitive``, and then sent to all connections of the session
    (including the one it came from).

    Args:
        ws: WebSocket connection supplied by flask-sock.
        session_id: Session id taken from the URL.

    Returns:
        None. The handler ends when ``receive()`` returns ``None``, the
        session has been removed, or the connection raises.
    """
    ip = request.remote_addr or "unknown"
    session = SESSIONS.get(session_id)
    if session is None:
        ws.send('{"type": "error", "message": "Session not found"}')
        ws.close()
        return

    # Roles follow join order, not message count, so two clients that join
    # before the first message is sent do not both become "Sender".
    join_number = session.get("join_count", 0) + 1
    session["join_count"] = join_number
    role = "Sender" if join_number == 1 else f"Receiver {join_number - 1}"
    session["connections"].append(ws)

    try:
        while True:
            msg = ws.receive()
            if msg is None:
                break

            if SESSIONS.get(session_id) is not session:
                # cleanup_session() removed the session while we were waiting.
                ws.send('{"type": "error", "message": "Session expired"}')
                break

            if isinstance(msg, (bytes, bytearray)):
                # Binary frames are handled as text so they can be stored and
                # re-broadcast as JSON.
                msg = bytes(msg).decode("utf-8", errors="replace")

            session["messages"].append({
                "role": role,
                "encrypted_msg": msg,
                "ip": ip,
                "timestamp": time.time(),
            })

            decrypted_text = decrypt_message(msg, session["password"])
            flag_if_sensitive(decrypted_text, ip, session_id, role, msg)

            # json.dumps escapes quotes and backslashes in the untrusted
            # message, so a client cannot break or forge the envelope.
            payload = json.dumps({"role": role, "encrypted_msg": msg})

            # Iterate over a copy: other handlers add and remove connections
            # while this loop runs.
            for conn in list(session["connections"]):
                try:
                    conn.send(payload)
                except Exception:
                    # Best effort: one failing peer must not stop delivery
                    # to the others.
                    continue
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Use the local reference: the session may no longer be in SESSIONS.
        if ws in session["connections"]:
            session["connections"].remove(ws)


@app.route("/")
def root():
    """Return a fixed plain-text status message for the index route."""
    status_message = "Real-time AI chat backend is running."
    return status_message


# Script entry point: start Flask's built-in server on all interfaces,
# port 7860.
if __name__ == "__main__":
    app.run(
        host="0.0.0.0",
        port=7860,
    )