Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from flask import Flask, jsonify, request | |
| from flask_cors import CORS | |
| from torchvision import transforms | |
| from dotenv import load_dotenv | |
| from timm import create_model | |
| from groq import Groq | |
| import time | |
| from pymongo import MongoClient | |
| from helpers import get_random_joke, get_one_fact | |
| from trivia import fetch_questionnaire, score_assessment, interpret_score | |
load_dotenv()

app = Flask(__name__)
# Allow cross-origin browser calls from any origin, but only on /api/* paths.
CORS(app, resources={r"/api/*": {"origins": "*"}})

MONGO_URI = os.getenv("MONGO_URI")
if MONGO_URI:
    # MongoClient connects lazily: constructing it does not prove the server is
    # reachable, so the log message below does not claim a live connection.
    client = MongoClient(MONGO_URI)
    db = client["mood_monitor"]
    sessions_col = db["sessions"]
    print("✅ MongoDB client configured")
else:
    # MongoClient(None) would silently target localhost:27017; leave the
    # handles unset so request handlers can report "Database not configured".
    client = None
    db = None
    sessions_col = None
    print("⚠️ MONGO_URI not set; session storage disabled")
# Classifier checkpoint and label set. The index order of EMOTION_LABELS is the
# order of the classifier's output logits.
MODEL_PATH = "best_fer2013_model_70.pth"
EMOTION_LABELS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
# Size of the classifier head; derived from the labels so the two stay in sync.
NUM_CLASSES = len(EMOTION_LABELS)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"⚙️ Using device: {DEVICE}")

# Set to True only after the checkpoint loads successfully.
MODEL_LOADED = False
# time.time() value of the most recent frame in which a face was detected.
last_face_time = 0
# Build the EfficientNet-B0 classifier and load its fine-tuned weights. Any
# failure is logged and leaves MODEL_LOADED False, in which case inference
# falls back to returning ("Neutral", 0.0).
try:
    print("Loading model...")
    model = create_model("efficientnet_b0", pretrained=False, num_classes=NUM_CLASSES)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    # nn.Module.to()/eval() return the module itself, so chaining is equivalent.
    model = model.to(DEVICE).eval()
except Exception as e:
    print("❌ Model load error:", e)
else:
    MODEL_LOADED = True
    print("✅ Model loaded")
# Groq chat client, created only when an API key is present in the environment;
# otherwise groq_client stays None.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
groq_client = None
if GROQ_API_KEY:
    groq_client = Groq(api_key=GROQ_API_KEY)
# Preprocessing for a face crop:
#   1. array -> PIL image
#   2. collapse to luminance and replicate it into 3 channels
#   3. resize to 224x224
#   4. convert to a tensor
#   5. normalize with the standard ImageNet mean/std values
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# OpenCV's bundled frontal-face Haar cascade, used to locate faces in frames.
face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
def detect_emotion_from_frame(face_img):
    """Classify the emotion shown in a single face crop.

    Args:
        face_img: image array of the face region, in whatever form the module's
            ``transform`` (starting with ``ToPILImage``) accepts — presumably an
            HxWx3 uint8 array; confirm against the caller.

    Returns:
        ``(label, confidence)``, where ``label`` is one of ``EMOTION_LABELS`` and
        ``confidence`` is its softmax probability. Returns ``("Neutral", 0.0)``
        when the model is not loaded or inference raises.
    """
    if not MODEL_LOADED:
        return "Neutral", 0.0
    try:
        batch = transform(face_img).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits = model(batch)
            probabilities = logits.softmax(dim=1)
            best_prob, best_idx = probabilities.max(dim=1)
        return EMOTION_LABELS[int(best_idx)], float(best_prob)
    except Exception as err:
        # Best-effort: a failed inference degrades to a neutral result.
        print("Inference error:", err)
        return "Neutral", 0.0
def health():
    """Report that the service is up and whether the emotion model loaded."""
    return jsonify(status="ok", model_loaded=MODEL_LOADED)
def analyze_frame():
    """Detect the largest face in an uploaded image and classify its emotion.

    Expects a multipart upload with the image file under the ``image`` field.

    Returns:
        JSON ``{"face_detected", "dominant_emotion", "emotions"}``.
        - A face is found: ``emotions`` maps the lower-cased label to its confidence.
        - No face is found, but one was seen less than 1.5 s ago: ``face_detected``
          stays True with empty emotion fields, smoothing brief detection drop-outs.
        - Error: 400 with ``{"error": ...}`` when the upload is missing, empty,
          or cannot be decoded.
    """
    global last_face_time

    if "image" not in request.files:
        return jsonify({"error": "No image"}), 400

    raw = request.files["image"].read()
    # cv2.imdecode raises on an empty buffer instead of returning None.
    if not raw:
        return jsonify({"error": "Bad image"}), 400
    frame = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        return jsonify({"error": "Bad image"}), 400

    # Equalized grayscale for the Haar cascade (presumably for robustness to lighting).
    gray = cv2.equalizeHist(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
    faces = face_cascade.detectMultiScale(gray, 1.1, 4, minSize=(30, 30))

    if len(faces) == 0:
        recently_seen = time.time() - last_face_time < 1.5
        return jsonify({"face_detected": recently_seen, "dominant_emotion": "", "emotions": {}})

    last_face_time = time.time()
    # Classify only the largest detection by area.
    x, y, w, h = max(faces, key=lambda r: r[2] * r[3])
    # cv2.imdecode yields BGR, but ToPILImage treats 3-channel arrays as RGB and
    # Grayscale then applies R/G/B luma weights; convert so the weights hit the
    # correct channels.
    face_rgb = cv2.cvtColor(frame[y:y + h, x:x + w], cv2.COLOR_BGR2RGB)
    emotion, confidence = detect_emotion_from_frame(face_rgb)
    return jsonify({
        "face_detected": True,
        "dominant_emotion": emotion,
        "emotions": {emotion.lower(): float(confidence)},
    })
def save_session():
    """Persist one monitoring session for a user.

    Expects a JSON object body with:
        - ``user_id`` (required)
        - ``duration`` (optional, int-convertible, default 0)
        - ``emotions`` (optional, default ``{}``)
        - ``dominant_emotion`` (optional, default ``""``)

    Returns:
        - ``{"success": True, "id": <inserted id>}`` on success.
        - 400 for a body that is not a JSON object, a missing ``user_id``, or a
          non-numeric ``duration``.
        - 500 when the database is not configured or the insert fails.
    """
    import traceback

    # silent=True: malformed JSON yields None (handled below) instead of raising.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid JSON body"}), 400
    print("📥 Received session data:", data)

    user_id = data.get("user_id")
    if not user_id:
        return jsonify({"error": "Missing user_id"}), 400
    if not MONGO_URI or client is None:
        return jsonify({"error": "Database not configured"}), 500

    try:
        duration = int(data.get("duration", 0))
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "Invalid duration"}), 400

    session_doc = {
        "user_id": str(user_id),
        "timestamp": time.time(),
        "duration": duration,
        "emotions": data.get("emotions", {}),
        "dominant_emotion": data.get("dominant_emotion", ""),
    }
    try:
        print("📝 Inserting into MongoDB:", session_doc)
        result = sessions_col.insert_one(session_doc)
    except Exception as e:
        # Database boundary: log with traceback and report as a server error.
        print("❌ SAVE SESSION ERROR:", str(e))
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    if not result.inserted_id:
        return jsonify({"error": "Failed to insert document"}), 500
    return jsonify({"success": True, "id": str(result.inserted_id)})
def list_sessions():
    """Return every stored session for ``?user_id=...`` as a JSON list.

    MongoDB's ``_id`` field is excluded from each document.

    Returns:
        - 400 when ``user_id`` is missing.
        - 500 when the database is not configured.
    """
    user_id = request.args.get("user_id")
    if not user_id:
        return jsonify({"error": "Missing user_id"}), 400
    # Fail fast instead of waiting on a server-selection timeout.
    if not MONGO_URI or client is None:
        return jsonify({"error": "Database not configured"}), 500
    sessions = list(sessions_col.find({"user_id": str(user_id)}, {"_id": 0}))
    return jsonify(sessions)
def analytics():
    """Summarize a user's stored sessions.

    Query parameter ``user_id`` is required.

    Returns:
        ``{"total_sessions", "emotion_distribution", "avg_duration"}``.
        - ``emotion_distribution`` counts sessions per non-empty
          ``dominant_emotion``.
        - ``avg_duration`` is the mean ``duration``, rounded to one decimal.
          Missing or null durations count as 0.
        - 400 when ``user_id`` is missing; 500 when the database is not configured.
    """
    from collections import Counter

    user_id = request.args.get("user_id")
    if not user_id:
        return jsonify({"error": "Missing user_id"}), 400
    # Fail fast instead of waiting on a server-selection timeout.
    if not MONGO_URI or client is None:
        return jsonify({"error": "Database not configured"}), 500

    # Fetch only the fields the summary needs.
    sessions = list(sessions_col.find(
        {"user_id": str(user_id)},
        {"_id": 0, "duration": 1, "dominant_emotion": 1},
    ))
    if not sessions:
        return jsonify({
            "total_sessions": 0,
            "emotion_distribution": {},
            "avg_duration": 0,
        })

    emotion_counts = Counter(
        s["dominant_emotion"] for s in sessions if s.get("dominant_emotion")
    )
    total_duration = sum(s.get("duration") or 0 for s in sessions)
    return jsonify({
        "total_sessions": len(sessions),
        "emotion_distribution": dict(emotion_counts),
        "avg_duration": round(total_duration / len(sessions), 1),
    })
def joke():
    """Return ``{"joke": ...}`` with the value produced by ``get_random_joke()``."""
    payload = {"joke": get_random_joke()}
    return jsonify(payload)
def fact():
    """Return ``{"fact": ...}`` from ``get_one_fact``.

    The value of the ``API_NINJAS_KEY`` environment variable (possibly None)
    is passed to ``get_one_fact``.
    """
    ninjas_key = os.environ.get("API_NINJAS_KEY")
    return jsonify({"fact": get_one_fact(ninjas_key)})
def grok():
    """Forward a user's question to the Groq chat-completions API.

    Expects a JSON body ``{"question": "<text>"}``.

    Returns:
        - ``{"reply": <model answer>}`` on success.
        - ``{"reply": "AI not configured"}`` when no Groq client exists.
        - 400 when the body has no non-empty string ``question``.
        - 502 when the upstream request fails.
    """
    if groq_client is None:
        return jsonify({"reply": "AI not configured"})

    # silent=True: a missing/invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    question = data.get("question") if isinstance(data, dict) else None
    if not isinstance(question, str) or not question.strip():
        return jsonify({"error": "Missing question"}), 400

    try:
        completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": question}],
            model="llama-3.3-70b-versatile",
        )
    except Exception as e:
        # Upstream API boundary: log and report as a bad-gateway error.
        print("❌ Groq request error:", e)
        return jsonify({"error": "AI request failed"}), 502
    return jsonify({"reply": completion.choices[0].message.content})
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port)