import os

# CRITICAL: Set environment variables BEFORE any imports to prevent training
os.environ['YOLO_VERBOSE'] = 'False'
os.environ['ULTRALYTICS_AUTOINSTALL'] = 'False'

# Force all HF caches to a writable place
_cache = "/data/hf-cache" if os.getenv("HF_SPACE") else os.getenv("HF_CACHE_DIR", "/tmp/hf-cache")
for var in ["HF_HOME", "HUGGINGFACE_HUB_CACHE", "HF_HUB_CACHE", "HF_CACHE_DIR", "XDG_CACHE_HOME"]:
    os.environ.setdefault(var, _cache)
os.makedirs(_cache, exist_ok=True)

from flask import Flask, render_template, request, jsonify, send_from_directory, url_for, Response
from werkzeug.utils import secure_filename
from PIL import Image
import io
import torch
import cv2
import numpy as np
from datetime import datetime
from huggingface_hub import hf_hub_download
import time
import traceback
from collections import deque
import shutil

app = Flask(__name__)
app.config["UPLOAD_FOLDER"] = os.environ.get("UPLOAD_DIR", "/data/uploads")
app.config["VIDEO_FOLDER"] = os.path.join(app.config["UPLOAD_FOLDER"], "videos")
os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
os.makedirs(app.config["VIDEO_FOLDER"], exist_ok=True)

# Exercise classes (fallback labels when the model carries no class names)
CLASSES = [
    "benchpress", "deadlift", "squat", "leg_ext", "pushup", "shoulder_press"
]

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}
ALLOWED_VIDEO_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'webm'}

# OPTIMIZED Performance settings
SKIP_FRAMES = 4       # streams run inference on every SKIP_FRAMES-th frame only
TARGET_FPS = 15       # uploaded-video streams sleep 1/TARGET_FPS s between frames
INFERENCE_SIZE = 416  # imgsz passed to the model
JPEG_QUALITY = 75     # quality of MJPEG stream frames
CONF_THRESHOLD = 0.25
IOU_THRESHOLD = 0.5

# Form tips returned by /analyze_image, keyed by detected label.
EXERCISE_TIPS = {
    "benchpress": "Feet planted, slight arch, shoulder blades retracted; control bar path.",
    "deadlift": "Hinge at hips, bar close to shins, lats tight; push the floor, don't jerk.",
    "squat": "Keep knees tracking over toes; brace your core; maintain neutral spine.",
    "leg_ext": "Control the movement, don't swing; focus on squeezing the quadriceps.",
    "pushup": "Keep body straight, engage core; lower chest to floor with control.",
    "shoulder_press": "Keep core tight, don't arch back excessively; press straight up.",
}

# Global variables
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
frame_times = deque(maxlen=30)  # NOTE(review): not referenced anywhere in this file
# Fallback detection cache for process_frame_optimized() callers that do not
# pass their own per-stream cache. The streaming routes use per-stream caches.
last_frame_cache = None


def _extension(filename: str) -> str:
    """Return the lower-cased text after the last '.', or "" if there is none."""
    return filename.rsplit(".", 1)[1].lower() if "." in filename else ""


def allowed_file(filename: str) -> bool:
    """Return True if *filename* has an accepted image extension."""
    return _extension(filename) in ALLOWED_EXTENSIONS


def allowed_video(filename: str) -> bool:
    """Return True if *filename* has an accepted video extension."""
    return _extension(filename) in ALLOWED_VIDEO_EXTENSIONS


def load_model():
    """Load the YOLO detector for inference only into the global ``model``.

    In a Hugging Face Space (``HF_SPACE`` set) the checkpoint is downloaded from
    the Hub; otherwise ``best_v4.pt`` is read from the working directory. The
    network is put in eval mode with gradients disabled, trainer/predictor state
    is cleared, predict-only overrides are installed, and one warmup inference
    is run on a random image.

    Returns:
        True on success. On any failure the traceback is printed, ``model`` is
        reset to None and False is returned.
    """
    global model

    print("\n" + "=" * 60)
    print("STARTING MODEL LOAD (INFERENCE-ONLY MODE)")
    print("=" * 60)

    # CRITICAL: Set anti-training environment variables
    os.environ['YOLO_VERBOSE'] = 'False'
    os.environ['ULTRALYTICS_AUTOINSTALL'] = 'False'

    try:
        # IMPORTANT: Update this with YOUR model repo
        if os.getenv("HF_SPACE"):
            print("Running in Hugging Face Space")
            # Download from your model repo
            checkpoint_path = hf_hub_download(
                repo_id="gym-vision/objdetection_model",  # ← CHANGE THIS!
                filename="best_v4.pt",
                repo_type="model",
                cache_dir=os.environ["HF_CACHE_DIR"],
            )
        else:
            checkpoint_path = "best_v4.pt"
            print(f"Local mode - Model at: {os.path.abspath(checkpoint_path)}")

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Model not found: {checkpoint_path}")

        print(f"Device: {device}")

        from ultralytics import YOLO

        # Load model
        model = YOLO(checkpoint_path)
        model.to(device)

        # Force evaluation mode
        if hasattr(model, 'model'):
            model.model.eval()
            model.model.requires_grad_(False)
            for param in model.model.parameters():
                param.requires_grad = False

        # Disable trainer
        if hasattr(model, 'trainer'):
            model.trainer = None

        # Override ALL settings
        if hasattr(model, 'overrides'):
            model.overrides = {
                'task': 'detect',
                'mode': 'predict',
                'model': checkpoint_path,
                'data': None,
                'epochs': 0,
                'save': False,
                'save_txt': False,
                'save_conf': False,
                'save_crop': False,
                'show': False,
                'plots': False,
                'verbose': False,
                'conf': CONF_THRESHOLD,
                'iou': IOU_THRESHOLD,
                'max_det': 10,
                'half': device.type == 'cuda',
                'device': device.type,
                'augment': False,
                'visualize': False,
                'batch': 1,
                'imgsz': INFERENCE_SIZE,
                'workers': 0,
            }

        # Drop any cached predictor so the next call builds one from the overrides above.
        if hasattr(model, 'predictor'):
            model.predictor = None

        print("✓ Model loaded in INFERENCE-ONLY mode")

        # Warmup
        print("\nWarming up model...")
        dummy_img = np.random.randint(0, 255, (INFERENCE_SIZE, INFERENCE_SIZE, 3), dtype=np.uint8)
        with torch.no_grad():
            try:
                model(dummy_img, verbose=False)
            except Exception as exc:
                # Warmup is best-effort; real failures surface on the first request.
                print(f"Warmup inference failed (continuing): {exc}")

        print("\n" + "=" * 60)
        print("MODEL READY FOR INFERENCE")
        print(f"Device: {device}")
        print("=" * 60 + "\n")

        return True

    except Exception as e:
        print("\n" + "=" * 60)
        print("MODEL LOADING FAILED")
        print(f"Error: {e}")
        traceback.print_exc()
        print("=" * 60 + "\n")
        model = None
        return False


# Pre-define colors for faster lookup (BGR format)
COLORS_BGR = {
    "benchpress": (107, 107, 255),
    "deadlift": (196, 205, 78),
    "squat": (209, 183, 69),
    "leg_ext": (122, 160, 255),
    "pushup": (200, 216, 152),
    "shoulder_press": (111, 220, 247),
}


def draw_detections_fast(image, detections, rgb=False):
    """Draw bounding boxes and "label confidence" tags onto an image.

    The tag is placed above the box when there is room, otherwise below it,
    otherwise inside the top of the box.

    Args:
        image: HxWx3 uint8 array (drawn on in place), or a PIL image, which is
            converted to an RGB array first.
        detections: dicts with 'bbox' ([x1, y1, x2, y2] ints), 'label' and
            'confidence', as produced by detect_objects_fast().
        rgb: True if *image* is in RGB channel order. COLORS_BGR holds BGR
            tuples, so they are reversed for RGB images. PIL input is always
            treated as RGB.

    Returns:
        The annotated array.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
        rgb = True

    img_h, img_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    label_margin = 8

    for det in detections:
        x1, y1, x2, y2 = det['bbox']
        label = det['label']
        conf = det['confidence']

        color = COLORS_BGR.get(label, (255, 255, 255))
        if rgb:
            color = color[::-1]

        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

        text = f"{label} {conf:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)

        if y1 - text_h - label_margin >= 0:
            # Room above the box.
            label_y1 = y1 - text_h - label_margin
            label_y2 = y1
            text_y = y1 - 4
        elif y2 + text_h + label_margin <= img_h:
            # Room below the box.
            label_y1 = y2
            label_y2 = y2 + text_h + label_margin
            text_y = y2 + text_h + 2
        else:
            # Neither fits: draw inside the top of the box.
            label_y1 = y1
            label_y2 = y1 + text_h + label_margin
            text_y = y1 + text_h + 2

        label_x2 = min(x1 + text_w + 4, img_w)
        cv2.rectangle(image, (x1, label_y1), (label_x2, label_y2), color, -1)
        cv2.putText(image, text, (x1 + 2, text_y), font, font_scale, (0, 0, 0), thickness)

    return image


def _class_name(cls_id):
    """Map a class index to a label: model.names, then CLASSES, else a placeholder."""
    names = getattr(model, 'names', None)
    if names is not None:
        try:
            return names[cls_id]
        except (KeyError, IndexError, TypeError):
            pass  # fall through to the static class list
    if 0 <= cls_id < len(CLASSES):
        return CLASSES[cls_id]
    return f"class_{cls_id}"


@torch.no_grad()
def detect_objects_fast(image_array, verbose=False):
    """Run the detector on one image and return plain-Python detections.

    Args:
        image_array: HxWx3 uint8 array in BGR order. Ultralytics interprets
            numpy input as BGR (OpenCV convention), so RGB arrays must be
            converted first.
        verbose: print the inference time and detection count.

    Returns:
        List of {'bbox': [x1, y1, x2, y2] ints, 'label': str, 'confidence': float}.
        Empty if the model is not loaded or inference raises (the error is printed).
    """
    if model is None:
        return []

    try:
        start_time = time.time()
        detections = []

        results = model(image_array, verbose=False, imgsz=INFERENCE_SIZE)

        if results:
            boxes = getattr(results[0], 'boxes', None)
            if boxes is not None:
                # One device->host copy per tensor instead of three per box.
                coords = boxes.xyxy.cpu().numpy().astype(int)
                confs = boxes.conf.cpu().numpy()
                cls_ids = boxes.cls.cpu().numpy().astype(int)
                for (x1, y1, x2, y2), conf, cls_id in zip(coords, confs, cls_ids):
                    detections.append({
                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
                        'label': _class_name(int(cls_id)),
                        'confidence': float(conf),
                    })

        if verbose:
            inference_time = (time.time() - start_time) * 1000
            print(f"Inference: {inference_time:.1f}ms | Detections: {len(detections)}")

        return detections

    except Exception as e:
        print(f"Detection error: {e}")
        return []


def process_frame_optimized(frame, frame_count=0, cache=None):
    """Detect and annotate one BGR video frame, reusing results on skipped frames.

    Inference runs only when ``frame_count % SKIP_FRAMES == 0`` or nothing is
    cached yet. Other frames return the cached annotated image (from the last
    inferred frame) and its detections.

    Args:
        frame: BGR HxWx3 uint8 array, e.g. from ``cv2.VideoCapture.read``.
        frame_count: index of this frame within its stream.
        cache: per-stream dict holding the last result. When None, the
            module-level ``last_frame_cache`` is used, which is shared by every
            caller that also passes None.

    Returns:
        (annotated frame in RGB order, list of detection dicts).
    """
    global last_frame_cache

    state = cache if cache is not None else last_frame_cache
    if frame_count % SKIP_FRAMES != 0 and state:
        return state['annotated'], state['detections']

    # Feed BGR straight to the model (its expected numpy channel order) and
    # draw with BGR colours; convert to RGB only for the returned image.
    detections = detect_objects_fast(frame)
    annotated_bgr = draw_detections_fast(frame.copy(), detections)
    annotated_frame = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

    result = {'annotated': annotated_frame, 'detections': detections}
    if cache is None:
        last_frame_cache = result
    else:
        cache.update(result)

    return annotated_frame, detections


def _mjpeg_stream(cap, frame_delay=0.0):
    """Yield multipart MJPEG chunks of annotated frames read from *cap*.

    Each call keeps its own detection cache, so concurrent streams (the server
    runs with threaded=True) never show each other's frames. *cap* is released
    when the source runs out or the client disconnects (generator closed).

    Args:
        cap: an opened ``cv2.VideoCapture``.
        frame_delay: seconds to sleep after each frame (0 disables throttling).
    """
    cache = {}
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY]
    frame_count = 0
    try:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            annotated_frame, _ = process_frame_optimized(frame, frame_count, cache=cache)
            frame_count += 1

            ok, buffer = cv2.imencode(
                '.jpg', cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR), encode_param
            )
            if ok:
                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')

            if frame_delay > 0:
                time.sleep(frame_delay)
    finally:
        cap.release()


@app.route("/")
def index():
    return render_template("index.html")


@app.route("/uploads/<path:filename>")
def uploaded_file(filename):
    """Serve a saved upload or annotated result from UPLOAD_FOLDER."""
    # send_from_directory refuses paths that escape the directory.
    return send_from_directory(app.config["UPLOAD_FOLDER"], filename)


@app.route("/webcam_feed")
def webcam_feed():
    """Stream the server's webcam (device 0) as annotated MJPEG.

    Note: Webcam will not work in Hugging Face Spaces (no camera access)
    """
    def generate():
        if model is None:
            print("ERROR: Model not loaded")
            return

        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print("ERROR: Could not open webcam")
            cap.release()
            return

        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        cap.set(cv2.CAP_PROP_FPS, 30)
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

        yield from _mjpeg_stream(cap)

    return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')


@app.route("/analyze_image", methods=["POST"])
def analyze_image():
    """Analyze an uploaded image.

    Saves the original and an annotated copy to UPLOAD_FOLDER and returns
    their URLs, the detections, and a form tip per distinct detected exercise
    (in order of first detection).
    """
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500

    if "image" not in request.files:
        return jsonify({"ok": False, "error": "No file part"}), 400

    file = request.files["image"]
    if not file.filename or not allowed_file(file.filename):
        return jsonify({"ok": False, "error": "Invalid file"}), 400

    try:
        image_bytes = file.read()

        # Decode before writing anything so non-images are rejected cleanly.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except OSError:  # includes PIL.UnidentifiedImageError
            return jsonify({"ok": False, "error": "Could not decode image"}), 400

        filename = secure_filename(file.filename)
        filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}_{filename}"
        save_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        with open(save_path, 'wb') as f:
            f.write(image_bytes)

        # The model expects BGR numpy input; COLORS_BGR are BGR too.
        image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        detections = detect_objects_fast(image_bgr, verbose=True)
        annotated_bgr = draw_detections_fast(image_bgr, detections)

        annotated_image = Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))
        annotated_filename = f"annotated_{filename}"
        annotated_path = os.path.join(app.config["UPLOAD_FOLDER"], annotated_filename)
        annotated_image.save(annotated_path, quality=95)

        # Distinct labels, in order of first detection.
        detected_exercises = list(dict.fromkeys(d['label'] for d in detections))
        exercise_tips = [EXERCISE_TIPS.get(ex, "") for ex in detected_exercises]

        return jsonify({
            "ok": True,
            "original_image": url_for("uploaded_file", filename=filename),
            "annotated_image": url_for("uploaded_file", filename=annotated_filename),
            "detections": detections,
            "tips": exercise_tips,
        })

    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        return jsonify({"ok": False, "error": str(e)}), 500


@app.route("/upload_video", methods=["POST"])
def upload_video():
    """Save an uploaded video to VIDEO_FOLDER and return its id for /video_feed."""
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500

    if "video" not in request.files:
        return jsonify({"ok": False, "error": "No video file"}), 400

    file = request.files["video"]
    if not file.filename or not allowed_video(file.filename):
        return jsonify({"ok": False, "error": "Invalid video"}), 400

    filename = secure_filename(file.filename)
    filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}_{filename}"
    save_path = os.path.join(app.config["VIDEO_FOLDER"], filename)
    file.save(save_path)

    return jsonify({"ok": True, "video_id": filename})


@app.route("/video_feed/<video_id>")
def video_feed(video_id):
    """Stream an uploaded video as annotated MJPEG, throttled to ~TARGET_FPS."""
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500

    # Ids issued by upload_video are already secure_filename()-clean; reject
    # anything else (e.g. "..") instead of joining it into a filesystem path.
    if secure_filename(video_id) != video_id:
        return jsonify({"ok": False, "error": "Invalid video id"}), 400

    video_path = os.path.join(app.config["VIDEO_FOLDER"], video_id)
    if not os.path.isfile(video_path):
        return jsonify({"ok": False, "error": "Video not found"}), 404

    def generate():
        yield from _mjpeg_stream(cv2.VideoCapture(video_path), frame_delay=1.0 / TARGET_FPS)

    return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')


# Load model on startup
print("\n" + "=" * 60)
print("FLASK APP STARTING")
print("=" * 60)
model_loaded = load_model()
if model_loaded:
    print("\n✓ App ready for inference")
    print(f"Device: {device}")
else:
    print("\n✗ Model failed to load")
print("=" * 60 + "\n")


if __name__ == "__main__":
    # IMPORTANT: Hugging Face Spaces requires port 7860
    port = int(os.environ.get("PORT", 7860))
    app.run(debug=False, host="0.0.0.0", port=port, threaded=True)