Spaces:
Sleeping
Sleeping
import os

# CRITICAL: these two variables are exported before any of the third-party
# imports further down the file, so they are already in the environment by
# the time those packages are imported.
os.environ.update({
    'YOLO_VERBOSE': 'False',
    'ULTRALYTICS_AUTOINSTALL': 'False',
})

# Pick one writable cache directory: "/data/hf-cache" when HF_SPACE is set,
# otherwise whatever HF_CACHE_DIR says (falling back to "/tmp/hf-cache").
if os.getenv("HF_SPACE"):
    _cache = "/data/hf-cache"
else:
    _cache = os.getenv("HF_CACHE_DIR", "/tmp/hf-cache")

# Point every Hugging Face / XDG cache variable at that directory.
# setdefault leaves any value that is already present in the environment alone.
for var in ("HF_HOME", "HUGGINGFACE_HUB_CACHE", "HF_HUB_CACHE", "HF_CACHE_DIR", "XDG_CACHE_HOME"):
    os.environ.setdefault(var, _cache)

os.makedirs(_cache, exist_ok=True)
| from flask import Flask, render_template, request, jsonify, send_from_directory, url_for, Response | |
| from werkzeug.utils import secure_filename | |
| import os | |
| from PIL import Image | |
| import io | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from datetime import datetime | |
| from huggingface_hub import hf_hub_download | |
| import time | |
| from collections import deque | |
| import shutil | |
# Flask application object and the directories that hold uploaded media.
app = Flask(__name__)

# Uploaded images go to UPLOAD_DIR (default "/data/uploads"); uploaded videos
# go to a "videos" sub-directory of that folder.
app.config.update(UPLOAD_FOLDER=os.environ.get("UPLOAD_DIR", "/data/uploads"))
app.config["VIDEO_FOLDER"] = os.path.join(app.config["UPLOAD_FOLDER"], "videos")

# Create both folders up front so request handlers can write into them.
os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
os.makedirs(app.config["VIDEO_FOLDER"], exist_ok=True)
# Exercise class names; used as a fallback label table when the loaded model
# does not provide a name for a predicted class id.
CLASSES = [
    "benchpress",
    "deadlift",
    "squat",
    "leg_ext",
    "pushup",
    "shoulder_press",
]

# Image extensions accepted by allowed_file() (compared in lower case).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}

# ---- Performance settings -------------------------------------------------
SKIP_FRAMES = 4        # inference runs on frames whose index is a multiple of this
TARGET_FPS = 15        # pacing of the uploaded-video stream (sleep of 1 / TARGET_FPS per frame)
INFERENCE_SIZE = 416   # imgsz handed to the model, and side length of the warm-up image
JPEG_QUALITY = 75      # JPEG quality used when encoding streamed frames
CONF_THRESHOLD = 0.25  # 'conf' value written into the model overrides
IOU_THRESHOLD = 0.5    # 'iou' value written into the model overrides

# ---- Module-level state ---------------------------------------------------
model = None  # populated by load_model(); None means "not loaded"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
frame_times = deque(maxlen=30)
last_frame_cache = None  # most recent {'annotated', 'detections'} pair, reused for skipped frames
def allowed_file(filename: str) -> bool:
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The extension is the text after the last "." and is compared
    case-insensitively. Names without a "." are rejected.
    """
    _, dot, extension = filename.rpartition(".")
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
def allowed_video(filename: str) -> bool:
    """Return True when *filename* ends in a supported video extension.

    The extension is the text after the last "." and is compared
    case-insensitively. Names without a "." are rejected.
    """
    VIDEO_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'webm'}
    if "." not in filename:
        return False
    suffix = filename.rsplit(".", 1)[-1]
    return suffix.lower() in VIDEO_EXTENSIONS
def load_model():
    """Load the YOLO checkpoint into the module-level ``model`` for inference only.

    When ``HF_SPACE`` is set the checkpoint ``best_v4.pt`` is downloaded with
    ``hf_hub_download``; otherwise ``best_v4.pt`` is read from the current
    working directory. After loading, the network is put in eval mode with
    gradients disabled, the wrapper's ``trainer``/``predictor`` are cleared,
    its ``overrides`` are replaced with predict-only settings, and one
    best-effort warm-up inference is run.

    Returns:
        bool: True when the model was loaded. False on any failure, in which
        case the error and traceback are printed and ``model`` is reset to
        None.
    """
    global model

    rule = "=" * 60
    print("\n" + rule)
    print("STARTING MODEL LOAD (INFERENCE-ONLY MODE)")
    print(rule)

    # CRITICAL: Set anti-training environment variables
    os.environ['YOLO_VERBOSE'] = 'False'
    os.environ['ULTRALYTICS_AUTOINSTALL'] = 'False'

    try:
        # IMPORTANT: Update this with YOUR model repo
        if os.getenv("HF_SPACE"):
            print("Running in Hugging Face Space")
            # .get() instead of [...] so a missing HF_CACHE_DIR falls back to
            # the hub's default cache instead of raising KeyError.
            checkpoint_path = hf_hub_download(
                repo_id="gym-vision/objdetection_model",  # ← CHANGE THIS!
                filename="best_v4.pt",
                repo_type="model",
                cache_dir=os.environ.get("HF_CACHE_DIR"),
            )
        else:
            checkpoint_path = "best_v4.pt"
            print(f"Local mode - Model at: {os.path.abspath(checkpoint_path)}")

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Model not found: {checkpoint_path}")

        print(f"Device: {device}")

        # Imported here (not at module top) so it happens after the
        # environment variables above have been set.
        from ultralytics import YOLO

        model = YOLO(checkpoint_path)
        model.to(device)

        # Force evaluation mode. requires_grad_(False) already covers every
        # parameter of the module, so no per-parameter loop is needed.
        if hasattr(model, 'model'):
            model.model.eval()
            model.model.requires_grad_(False)

        # Disable trainer
        if hasattr(model, 'trainer'):
            model.trainer = None

        # Replace the wrapper's overrides wholesale with predict-only settings.
        if hasattr(model, 'overrides'):
            model.overrides = {
                'task': 'detect',
                'mode': 'predict',
                'model': checkpoint_path,
                'data': None,
                'epochs': 0,
                'save': False,
                'save_txt': False,
                'save_conf': False,
                'save_crop': False,
                'show': False,
                'plots': False,
                'verbose': False,
                'conf': CONF_THRESHOLD,
                'iou': IOU_THRESHOLD,
                'max_det': 10,
                'half': device.type == 'cuda',
                'device': device.type,
                'augment': False,
                'visualize': False,
                'batch': 1,
                'imgsz': INFERENCE_SIZE,
                'workers': 0,
            }

        # Drop any existing predictor so the next call builds one from the
        # overrides assigned above.
        if hasattr(model, 'predictor'):
            model.predictor = None

        print("✓ Model loaded in INFERENCE-ONLY mode")

        # Warm-up is best effort: a failure is reported but does not fail the
        # load. Only Exception is caught so Ctrl-C / SystemExit still propagate.
        print("\nWarming up model...")
        dummy_img = np.random.randint(0, 255, (INFERENCE_SIZE, INFERENCE_SIZE, 3), dtype=np.uint8)
        with torch.no_grad():
            try:
                model(dummy_img, verbose=False)
            except Exception as warmup_error:
                print(f"Warmup skipped: {warmup_error}")

        print("\n" + rule)
        print("MODEL READY FOR INFERENCE")
        print(f"Device: {device}")
        print(rule + "\n")
        return True

    except Exception as e:
        print("\n" + rule)
        print("MODEL LOADING FAILED")
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print(rule + "\n")
        model = None
        return False
# Per-class box/label colours, looked up by label in draw_detections_fast().
# The tuples are written in BGR channel order.
# NOTE(review): the callers visible in this file hand RGB arrays to the
# drawing routine, so these values end up interpreted as RGB - confirm which
# palette is intended.
COLORS_BGR = {
    "benchpress": (107, 107, 255),
    "deadlift": (196, 205, 78),
    "squat": (209, 183, 69),
    "leg_ext": (122, 160, 255),
    "pushup": (200, 216, 152),
    "shoulder_press": (111, 220, 247),
}
def draw_detections_fast(image, detections):
    """Draw a box and a "<label> <confidence>" tag for every detection.

    Args:
        image: numpy image array, or a PIL image (converted to an array first).
        detections: iterable of dicts with 'bbox' ([x1, y1, x2, y2]),
            'label' and 'confidence' keys.

    Returns:
        The array that was drawn on.

    The tag is placed above the box when it fits inside the image, otherwise
    below the box when that fits, otherwise just inside the box's top edge.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)

    height, width = image.shape[:2]
    face = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.6
    stroke = 2
    margin = 8

    for item in detections:
        left, top, right, bottom = item['bbox']
        name = item['label']
        score = item['confidence']
        # Unknown labels are drawn in white.
        colour = COLORS_BGR.get(name, (255, 255, 255))

        cv2.rectangle(image, (left, top), (right, bottom), colour, 2)

        caption = f"{name} {score:.2f}"
        (cap_w, cap_h), _ = cv2.getTextSize(caption, face, scale, stroke)
        band_h = cap_h + margin

        if top - band_h >= 0:
            # Enough room above the box.
            band_top, band_bottom = top - band_h, top
            baseline = top - 4
        else:
            # Hang the tag from the bottom edge when it fits, else from the top edge.
            anchor = bottom if bottom + band_h <= height else top
            band_top, band_bottom = anchor, anchor + band_h
            baseline = anchor + cap_h + 2

        # Keep the filled background from running past the right image border.
        band_right = min(left + cap_w + 4, width)
        cv2.rectangle(image, (left, band_top), (band_right, band_bottom), colour, -1)
        cv2.putText(image, caption, (left + 2, baseline), face, scale, (0, 0, 0), stroke)

    return image
def detect_objects_fast(image_array, verbose=False):
    """Run the module-level model on one image and return its detections.

    Args:
        image_array: image handed straight to the model call (the callers in
            this file pass RGB numpy arrays).
        verbose: when True, print the inference time and detection count.

    Returns:
        list[dict]: one dict per detected box with keys 'bbox'
        ([x1, y1, x2, y2] as ints), 'label' and 'confidence' (float).
        An empty list is returned when the model is not loaded, when nothing
        is detected, or when inference raises (the error is printed).
    """
    if model is None:
        return []

    try:
        start_time = time.perf_counter()  # monotonic clock for interval timing
        results = model(image_array, verbose=False, imgsz=INFERENCE_SIZE)

        detections = []
        boxes = getattr(results[0], 'boxes', None) if results else None
        if boxes is not None:
            names = getattr(model, 'names', None)

            # Move each tensor to the host once rather than once per box.
            coords = boxes.xyxy.cpu().numpy()
            scores = boxes.conf.cpu().numpy()
            class_ids = boxes.cls.cpu().numpy()

            for xyxy, score, cls_value in zip(coords, scores, class_ids):
                cls_id = int(cls_value)
                try:
                    label = names[cls_id]
                except (KeyError, IndexError, TypeError):
                    # Model has no name for this id: fall back to CLASSES, and
                    # to a generic tag rather than raising and losing every
                    # detection of the frame.
                    if 0 <= cls_id < len(CLASSES):
                        label = CLASSES[cls_id]
                    else:
                        label = f"class_{cls_id}"

                x1, y1, x2, y2 = (int(v) for v in xyxy)
                detections.append({
                    'bbox': [x1, y1, x2, y2],
                    'label': label,
                    'confidence': float(score),
                })

        if verbose:
            inference_time = (time.perf_counter() - start_time) * 1000
            print(f"Inference: {inference_time:.1f}ms | Detections: {len(detections)}")
        return detections

    except Exception as e:
        # Best effort: a failed inference yields "no detections" for the frame.
        print(f"Detection error: {e}")
        return []
def process_frame_optimized(frame, frame_count=0):
    """Annotate one BGR frame, reusing the previous result on skipped frames.

    Inference runs only when ``frame_count`` is a multiple of SKIP_FRAMES or
    when nothing has been cached yet; every other frame returns the cached
    annotated image and detections unchanged.

    Returns:
        tuple: (annotated RGB frame, list of detections).
    """
    global last_frame_cache

    is_inference_frame = frame_count % SKIP_FRAMES == 0
    cached = last_frame_cache
    if cached is not None and not is_inference_frame:
        return cached['annotated'], cached['detections']

    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    found = detect_objects_fast(rgb)
    # Draw on a copy so the converted frame itself is left untouched.
    drawn = draw_detections_fast(rgb.copy(), found)

    last_frame_cache = {'annotated': drawn, 'detections': found}
    return drawn, found
def index():
    """Render and return the ``index.html`` template.

    NOTE(review): no route decorator is visible on this function in the
    reviewed source - confirm how it is registered with the app.
    """
    page = render_template("index.html")
    return page
def uploaded_file(filename):
    """Serve *filename* from the configured upload folder.

    analyze_image() builds its image URLs with ``url_for("uploaded_file", ...)``.
    NOTE(review): no route decorator is visible on this function in the
    reviewed source - confirm how it is registered with the app.
    """
    upload_dir = app.config["UPLOAD_FOLDER"]
    return send_from_directory(upload_dir, filename)
def webcam_feed():
    """Stream annotated webcam frames as a multipart MJPEG response.

    Note: Webcam will not work in Hugging Face Spaces (no camera access).
    The stream ends immediately (empty body) when the model is not loaded or
    camera 0 cannot be opened.
    """
    def generate():
        global last_frame_cache
        # Start every stream with an empty cache.
        last_frame_cache = None

        if model is None:
            print("ERROR: Model not loaded")
            return

        camera = cv2.VideoCapture(0)
        if not camera.isOpened():
            print("ERROR: Could not open webcam")
            return

        capture_settings = (
            (cv2.CAP_PROP_FRAME_WIDTH, 640),
            (cv2.CAP_PROP_FRAME_HEIGHT, 480),
            (cv2.CAP_PROP_FPS, 30),
            (cv2.CAP_PROP_BUFFERSIZE, 1),
        )
        for prop, value in capture_settings:
            camera.set(prop, value)

        jpeg_options = [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY]
        frame_index = 0

        try:
            while True:
                grabbed, raw_frame = camera.read()
                if not grabbed:
                    break

                annotated, _ = process_frame_optimized(raw_frame, frame_index)
                # The annotated frame is RGB; convert back to BGR before encoding.
                bgr = cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR)
                _, encoded = cv2.imencode('.jpg', bgr, jpeg_options)
                payload = encoded.tobytes()
                frame_index += 1

                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n\r\n' + payload + b'\r\n')
        finally:
            # Runs on normal end and when the client disconnects mid-stream.
            camera.release()
            last_frame_cache = None

    return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')
def analyze_image():
    """Run detection on an uploaded image and return the results as JSON.

    Expects a multipart upload under the ``image`` field whose name passes
    allowed_file(). The original upload and an annotated copy are written to
    the upload folder.

    Returns:
        A JSON response with ``ok``, ``original_image`` / ``annotated_image``
        URLs, ``detections`` and ``tips`` on success; ``(json, 400)`` for a
        missing or disallowed file; ``(json, 500)`` when the model is not
        loaded or processing fails.
    """
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500
    if "image" not in request.files:
        return jsonify({"ok": False, "error": "No file part"}), 400

    file = request.files["image"]
    if file.filename == "" or not allowed_file(file.filename):
        return jsonify({"ok": False, "error": "Invalid file"}), 400

    try:
        image_bytes = file.read()

        # Decode before touching the disk so an upload that is not a readable
        # image is rejected without leaving a file behind.
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        # secure_filename() can strip a non-ASCII stem entirely and return
        # just "jpg" (no dot). Without an extension the annotated save below
        # cannot pick an output format, so rebuild a name from the extension
        # that was already validated above.
        safe_name = secure_filename(file.filename)
        if not allowed_file(safe_name):
            extension = file.filename.rsplit(".", 1)[1].lower()
            safe_name = f"upload.{extension}"

        filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}_{safe_name}"
        save_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        with open(save_path, 'wb') as f:
            f.write(image_bytes)

        image_array = np.array(image)
        detections = detect_objects_fast(image_array, verbose=True)

        # Draw on a copy so the decoded image array stays unmodified.
        annotated_array = draw_detections_fast(image_array.copy(), detections)
        annotated_image = Image.fromarray(annotated_array)
        annotated_filename = f"annotated_{filename}"
        annotated_path = os.path.join(app.config["UPLOAD_FOLDER"], annotated_filename)
        annotated_image.save(annotated_path, quality=95)

        tips = {
            "benchpress": "Feet planted, slight arch, shoulder blades retracted; control bar path.",
            "deadlift": "Hinge at hips, bar close to shins, lats tight; push the floor, don't jerk.",
            "squat": "Keep knees tracking over toes; brace your core; maintain neutral spine.",
            "leg_ext": "Control the movement, don't swing; focus on squeezing the quadriceps.",
            "pushup": "Keep body straight, engage core; lower chest to floor with control.",
            "shoulder_press": "Keep core tight, don't arch back excessively; press straight up."
        }

        # Unique labels in first-seen order, so the tips come back in a stable
        # order instead of depending on set iteration order.
        detected_exercises = list(dict.fromkeys(d['label'] for d in detections))
        exercise_tips = [tips.get(ex, "") for ex in detected_exercises]

        return jsonify({
            "ok": True,
            "original_image": url_for("uploaded_file", filename=filename),
            "annotated_image": url_for("uploaded_file", filename=annotated_filename),
            "detections": detections,
            "tips": exercise_tips
        })
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"ok": False, "error": str(e)}), 500
def upload_video():
    """Store an uploaded video and return the id used to stream it.

    Expects a multipart upload under the ``video`` field whose name passes
    allowed_video(). The file is saved in the video folder under a
    timestamp-prefixed, sanitised name, and that name is returned as
    ``video_id``.

    Returns:
        ``{"ok": True, "video_id": ...}`` on success; ``(json, 400)`` for a
        missing or disallowed file; ``(json, 500)`` when the model is not
        loaded.
    """
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500

    if "video" not in request.files:
        return jsonify({"ok": False, "error": "No video file"}), 400

    upload = request.files["video"]
    if not (upload.filename and allowed_video(upload.filename)):
        return jsonify({"ok": False, "error": "Invalid video"}), 400

    # Timestamp prefix (down to microseconds) keeps stored names distinct.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    stored_name = f"{stamp}_{secure_filename(upload.filename)}"
    destination = os.path.join(app.config["VIDEO_FOLDER"], stored_name)
    upload.save(destination)

    return jsonify({"ok": True, "video_id": stored_name})
def video_feed(video_id):
    """Stream an uploaded video, annotated frame by frame, as multipart MJPEG.

    Args:
        video_id: file name inside the video folder, as returned by
            upload_video().

    Returns:
        A streaming ``multipart/x-mixed-replace`` response, or
        ``(json, 400)`` for an id that is not a bare file name,
        ``(json, 404)`` when the file does not exist, and ``(json, 500)``
        when the model is not loaded.
    """
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500

    # video_id comes from the request: accept only a bare file name so it
    # cannot point outside the video folder.
    if video_id in ("", ".", "..") or "/" in video_id or "\\" in video_id:
        return jsonify({"ok": False, "error": "Invalid video"}), 400

    video_path = os.path.join(app.config["VIDEO_FOLDER"], video_id)
    if not os.path.isfile(video_path):
        # Report a missing file instead of answering 200 with an empty stream.
        return jsonify({"ok": False, "error": "Video not found"}), 404

    def generate():
        global last_frame_cache
        # Start every stream with an empty cache.
        last_frame_cache = None

        cap = cv2.VideoCapture(video_path)
        frame_count = 0
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY]

        try:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                annotated_frame, _ = process_frame_optimized(frame, frame_count)
                # The annotated frame is RGB; convert back to BGR before encoding.
                _, buffer = cv2.imencode(
                    '.jpg', cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR), encode_param
                )
                frame_bytes = buffer.tobytes()
                frame_count += 1

                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')

                # Pace the stream at roughly TARGET_FPS frames per second.
                time.sleep(1.0 / TARGET_FPS)
        finally:
            # Also runs when the client disconnects mid-stream (the generator
            # is closed at the yield), so the capture handle is never leaked.
            cap.release()
            last_frame_cache = None

    return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')
# Load the model once at import time so it is ready before the first request,
# whether the module is run directly or imported by a WSGI server.
print("\n" + "=" * 60)
print("FLASK APP STARTING")
print("=" * 60)

model_loaded = load_model()

if not model_loaded:
    print("\n✗ Model failed to load")
else:
    print("\n✓ App ready for inference")
    print(f"Device: {device}")

print("=" * 60 + "\n")

if __name__ == "__main__":
    # IMPORTANT: Hugging Face Spaces requires port 7860 (PORT overrides it).
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False, threaded=True)