# derpiestme's picture
# Update app.py
# 475c619 verified
import os
# These switches are written to the environment before the heavy imports
# below so that libraries reading them at import time see the final values.
# NOTE(review): Ultralytics appears to read YOLO_AUTOINSTALL rather than
# ULTRALYTICS_AUTOINSTALL - confirm the second variable is actually honoured.
os.environ.update({"YOLO_VERBOSE": "False", "ULTRALYTICS_AUTOINSTALL": "False"})

# Pick one cache directory: /data/hf-cache when HF_SPACE is set, otherwise
# HF_CACHE_DIR if provided, falling back to /tmp/hf-cache.
if os.getenv("HF_SPACE"):
    _cache = "/data/hf-cache"
else:
    _cache = os.getenv("HF_CACHE_DIR", "/tmp/hf-cache")

# Point every cache-related variable at that directory, but never override a
# value the deployment has already set.
for var in ("HF_HOME", "HUGGINGFACE_HUB_CACHE", "HF_HUB_CACHE", "HF_CACHE_DIR", "XDG_CACHE_HOME"):
    if var not in os.environ:
        os.environ[var] = _cache

os.makedirs(_cache, exist_ok=True)
from flask import Flask, render_template, request, jsonify, send_from_directory, url_for, Response
from werkzeug.utils import secure_filename
import os
from PIL import Image
import io
import torch
import cv2
import numpy as np
from datetime import datetime
from huggingface_hub import hf_hub_download
import time
from collections import deque
import shutil
app = Flask(__name__)

# Uploaded images are stored directly in UPLOAD_FOLDER (overridable through
# the UPLOAD_DIR environment variable); videos go into its "videos" subfolder.
app.config.update(UPLOAD_FOLDER=os.environ.get("UPLOAD_DIR", "/data/uploads"))
app.config.update(VIDEO_FOLDER=os.path.join(app.config["UPLOAD_FOLDER"], "videos"))

# Create both directories at import time so request handlers can write to them.
for _folder_key in ("UPLOAD_FOLDER", "VIDEO_FOLDER"):
    os.makedirs(app.config[_folder_key], exist_ok=True)
# Exercise labels, indexed by class id. Used as a fallback when the loaded
# model does not provide a name for a predicted class.
CLASSES = ["benchpress", "deadlift", "squat", "leg_ext", "pushup", "shoulder_press"]

# File extensions accepted by the image upload endpoint.
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp"}

# Performance settings
SKIP_FRAMES = 4        # detection runs only on frames whose index is a multiple of this
TARGET_FPS = 15        # uploaded-video streaming sleeps 1/TARGET_FPS between frames
INFERENCE_SIZE = 416   # imgsz handed to the model
JPEG_QUALITY = 75      # quality used when encoding streamed frames
CONF_THRESHOLD = 0.25  # 'conf' value placed in the model overrides
IOU_THRESHOLD = 0.5    # 'iou' value placed in the model overrides
# Shared module state
model = None  # assigned by load_model(); None means no model is available
# Prefer the GPU when torch reports one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
frame_times = deque(maxlen=30)  # bounded buffer holding at most 30 entries
last_frame_cache = None  # most recent annotated frame + detections, reused for skipped frames
def allowed_file(filename: str) -> bool:
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The comparison is case-insensitive and only the text after the last dot
    is considered; a name without any dot is rejected.
    """
    _, dot, extension = filename.rpartition(".")
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
def allowed_video(filename: str) -> bool:
    """Return True when *filename* ends in a supported video extension.

    Supported extensions are mp4, avi, mov, mkv and webm, matched
    case-insensitively against the text after the last dot.
    """
    _, dot, extension = filename.rpartition(".")
    if not dot:
        return False
    return extension.lower() in {"mp4", "avi", "mov", "mkv", "webm"}
def load_model():
    """Load the YOLO checkpoint into the module-level ``model`` for inference.

    When ``HF_SPACE`` is set the checkpoint is downloaded from the Hugging
    Face Hub; otherwise ``best_v4.pt`` is read from the working directory.
    The network is switched to eval mode with gradients disabled, the
    prediction overrides are replaced wholesale, and a single dummy inference
    is run as a best-effort warmup.

    Returns:
        True when the model is ready. False when anything in the loading
        sequence fails; the error and traceback are printed and ``model`` is
        reset to None.
    """
    global model
    print("\n" + "=" * 60)
    print("STARTING MODEL LOAD (INFERENCE-ONLY MODE)")
    print("=" * 60)
    # Re-assert the environment switches in case something changed them
    # between import time and this call.
    os.environ['YOLO_VERBOSE'] = 'False'
    os.environ['ULTRALYTICS_AUTOINSTALL'] = 'False'
    try:
        # IMPORTANT: Update this with YOUR model repo
        if os.getenv("HF_SPACE"):
            print("Running in Hugging Face Space")
            # Download from your model repo
            checkpoint_path = hf_hub_download(
                repo_id="gym-vision/objdetection_model",  # ← CHANGE THIS!
                filename="best_v4.pt",
                repo_type="model",
                cache_dir=os.environ["HF_CACHE_DIR"]
            )
        else:
            checkpoint_path = "best_v4.pt"
            print(f"Local mode - Model at: {os.path.abspath(checkpoint_path)}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Model not found: {checkpoint_path}")
        print(f"Device: {device}")

        # Imported lazily so the environment variables above are in place
        # before ultralytics is first loaded.
        from ultralytics import YOLO

        # Load model
        model = YOLO(checkpoint_path)
        model.to(device)

        # Force evaluation mode. requires_grad_(False) already covers every
        # parameter of the module, so no per-parameter loop is needed.
        if hasattr(model, 'model'):
            model.model.eval()
            model.model.requires_grad_(False)

        # Disable trainer
        if hasattr(model, 'trainer'):
            model.trainer = None

        # Override ALL settings
        if hasattr(model, 'overrides'):
            model.overrides = {
                'task': 'detect',
                'mode': 'predict',
                'model': checkpoint_path,
                'data': None,
                'epochs': 0,
                'save': False,
                'save_txt': False,
                'save_conf': False,
                'save_crop': False,
                'show': False,
                'plots': False,
                'verbose': False,
                'conf': CONF_THRESHOLD,
                'iou': IOU_THRESHOLD,
                'max_det': 10,
                'half': device.type == 'cuda',
                'device': device.type,
                'augment': False,
                'visualize': False,
                'batch': 1,
                'imgsz': INFERENCE_SIZE,
                'workers': 0,
            }

        # Drop any existing predictor so the next call builds one from the
        # overrides set above.
        if hasattr(model, 'predictor'):
            model.predictor = None
        print("✓ Model loaded in INFERENCE-ONLY mode")

        # Warmup
        print("\nWarming up model...")
        dummy_img = np.random.randint(0, 255, (INFERENCE_SIZE, INFERENCE_SIZE, 3), dtype=np.uint8)
        with torch.no_grad():
            try:
                _ = model(dummy_img, verbose=False)
            except Exception as warmup_error:
                # Warmup is best-effort: report the problem instead of hiding
                # it, but keep the loaded model. A bare ``except`` here would
                # also have swallowed KeyboardInterrupt / SystemExit.
                print(f"Warmup skipped: {warmup_error}")

        print("\n" + "=" * 60)
        print("MODEL READY FOR INFERENCE")
        print(f"Device: {device}")
        print("=" * 60 + "\n")
        return True
    except Exception as e:
        print("\n" + "=" * 60)
        print("MODEL LOADING FAILED")
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print("=" * 60 + "\n")
        model = None
        return False
# Per-exercise box colours, precomputed for fast lookup.
# Tuples are in (blue, green, red) channel order; the hex value each one
# represents is given alongside.
COLORS_BGR = {
    "benchpress": (107, 107, 255),      # #FF6B6B
    "deadlift": (196, 205, 78),         # #4ECDC4
    "squat": (209, 183, 69),            # #45B7D1
    "leg_ext": (122, 160, 255),         # #FFA07A
    "pushup": (200, 216, 152),          # #98D8C8
    "shoulder_press": (111, 220, 247),  # #F7DC6F
}
def draw_detections_fast(image, detections):
    """Draw bounding boxes and "label confidence" tags onto an RGB image.

    Args:
        image: RGB frame as a numpy array, or a PIL image (converted with
            ``np.array``). A numpy array is drawn on in place, so pass a copy
            if the original must stay untouched.
        detections: Iterable of dicts with keys ``'bbox'`` (``[x1, y1, x2, y2]``
            integers), ``'label'`` and ``'confidence'``.

    Returns:
        The annotated numpy array.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    img_h, img_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    label_margin = 8
    for det in detections:
        x1, y1, x2, y2 = det['bbox']
        label = det['label']
        conf = det['confidence']
        # The palette is stored in BGR order, but every caller in this file
        # hands over an RGB array, so flip the tuple to keep red and blue from
        # being swapped in the rendered output.
        color = COLORS_BGR.get(label, (255, 255, 255))[::-1]
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        text = f"{label} {conf:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        if y1 - text_h - label_margin >= 0:
            # Enough room above the box: put the tag on top of it.
            label_y1 = y1 - text_h - label_margin
            label_y2 = y1
            text_y = y1 - 4
        elif y2 + text_h + label_margin <= img_h:
            # Otherwise hang the tag underneath the box.
            label_y1 = y2
            label_y2 = y2 + text_h + label_margin
            text_y = y2 + text_h + 2
        else:
            # No room either side: draw the tag inside the top of the box.
            label_y1 = y1
            label_y2 = y1 + text_h + label_margin
            text_y = y1 + text_h + 2
        # Keep the tag background inside the right-hand image border.
        label_x2 = min(x1 + text_w + 4, img_w)
        cv2.rectangle(image, (x1, label_y1), (label_x2, label_y2), color, -1)
        cv2.putText(image, text, (x1 + 2, text_y), font, font_scale, (0, 0, 0), thickness)
    return image
@torch.no_grad()
def detect_objects_fast(image_array, verbose=False):
    """Run the loaded model on one image and return its detections.

    Args:
        image_array: Image passed straight to the model call.
        verbose: When True, print the inference time and detection count.

    Returns:
        A list of dicts with keys ``'bbox'`` (``[x1, y1, x2, y2]`` as ints),
        ``'label'`` (str) and ``'confidence'`` (float). The list is empty when
        no model is loaded, nothing is detected, or inference raises (the
        error is printed).
    """
    if model is None:
        return []
    try:
        start_time = time.time()
        detections = []
        # Use model call
        results = model(image_array, verbose=False, imgsz=INFERENCE_SIZE)
        boxes = getattr(results[0], 'boxes', None) if results else None
        if boxes is not None:
            # Move each tensor to the host once instead of three transfers
            # per box.
            coords = boxes.xyxy.cpu().numpy()
            scores = boxes.conf.cpu().numpy()
            class_ids = boxes.cls.cpu().numpy()
            names = getattr(model, 'names', None)
            for (x1, y1, x2, y2), score, raw_cls in zip(coords, scores, class_ids):
                cls_id = int(raw_cls)
                if names is not None and cls_id < len(names):
                    label = names[cls_id]
                elif 0 <= cls_id < len(CLASSES):
                    label = CLASSES[cls_id]
                else:
                    # An unknown class id must not discard the whole frame's
                    # detections via the broad except below.
                    label = f"class_{cls_id}"
                detections.append({
                    'bbox': [int(x1), int(y1), int(x2), int(y2)],
                    'label': label,
                    'confidence': float(score)
                })
        inference_time = (time.time() - start_time) * 1000
        if verbose:
            print(f"Inference: {inference_time:.1f}ms | Detections: {len(detections)}")
        return detections
    except Exception as e:
        # Deliberately broad: a failed frame should degrade to "no
        # detections" rather than break the request or stream.
        print(f"Detection error: {e}")
        return []
def process_frame_optimized(frame, frame_count=0):
    """Annotate one BGR frame, reusing the cached result on skipped frames.

    Detection only runs when ``frame_count`` is a multiple of SKIP_FRAMES or
    when nothing has been cached yet; every other call returns the cached
    annotated frame and detections unchanged.

    Returns:
        ``(annotated_rgb_frame, detections)``.
    """
    global last_frame_cache
    cached = last_frame_cache
    is_detection_frame = frame_count % SKIP_FRAMES == 0
    if cached is not None and not is_detection_frame:
        return cached['annotated'], cached['detections']

    # Detection and drawing both work on the RGB version of the frame.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    detections = detect_objects_fast(rgb_frame)
    annotated_frame = draw_detections_fast(rgb_frame.copy(), detections)
    last_frame_cache = {'annotated': annotated_frame, 'detections': detections}
    return annotated_frame, detections
@app.route("/")
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/uploads/<path:filename>")
def uploaded_file(filename):
    """Serve a previously stored file from the configured upload folder."""
    upload_root = app.config["UPLOAD_FOLDER"]
    return send_from_directory(upload_root, filename)
@app.route("/webcam_feed")
def webcam_feed():
    """Stream annotated frames from the server's camera 0 as multipart JPEG.

    Note: Webcam will not work in Hugging Face Spaces (no camera access)
    """
    def generate():
        global last_frame_cache
        # Start every stream without a stale cached frame.
        last_frame_cache = None
        if model is None:
            print("ERROR: Model not loaded")
            return
        camera = cv2.VideoCapture(0)
        if not camera.isOpened():
            print("ERROR: Could not open webcam")
            return

        # Requested capture settings, applied in this order.
        capture_settings = (
            (cv2.CAP_PROP_FRAME_WIDTH, 640),
            (cv2.CAP_PROP_FRAME_HEIGHT, 480),
            (cv2.CAP_PROP_FPS, 30),
            (cv2.CAP_PROP_BUFFERSIZE, 1),
        )
        for prop, value in capture_settings:
            camera.set(prop, value)

        jpeg_options = [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY]
        frames_seen = 0
        try:
            while True:
                grabbed, frame = camera.read()
                if not grabbed:
                    break
                annotated_rgb, _ = process_frame_optimized(frame, frames_seen)
                # The annotated frame is RGB; the JPEG encoder expects BGR.
                annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
                _, encoded = cv2.imencode('.jpg', annotated_bgr, jpeg_options)
                payload = encoded.tobytes()
                frames_seen += 1
                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n\r\n' + payload + b'\r\n')
        finally:
            # Runs on normal end and when the client disconnects mid-stream.
            camera.release()
            last_frame_cache = None

    return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')
@app.route("/analyze_image", methods=["POST"])
def analyze_image():
    """Run detection on an uploaded image and return annotated results.

    Expects a multipart form field named ``image``. The original upload and
    an annotated copy are both written to the upload folder.

    Returns:
        JSON with ``ok``, URLs of the original and annotated images, the
        detections and one form tip per distinct detected exercise (in order
        of first detection). Responds 400 for a missing or disallowed file
        and 500 when the model is not loaded or processing fails.
    """
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500
    if "image" not in request.files:
        return jsonify({"ok": False, "error": "No file part"}), 400
    file = request.files["image"]
    # ``not file.filename`` also covers a missing (None) filename, which would
    # otherwise crash inside allowed_file().
    if not file.filename or not allowed_file(file.filename):
        return jsonify({"ok": False, "error": "Invalid file"}), 400
    try:
        image_bytes = file.read()
        filename = secure_filename(file.filename)
        if not allowed_file(filename):
            # secure_filename() drops non-ASCII characters and strips leading
            # dots, so a name made only of such characters plus ".jpg"
            # collapses to "jpg". Without an extension the annotated image
            # cannot be saved, so rebuild a name from the validated extension.
            extension = file.filename.rsplit(".", 1)[1].lower()
            filename = f"upload.{extension}"
        filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}_{filename}"
        save_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        with open(save_path, 'wb') as f:
            f.write(image_bytes)

        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        image_array = np.array(image)
        detections = detect_objects_fast(image_array, verbose=True)
        # Draw on a copy so image_array stays untouched.
        annotated_array = draw_detections_fast(image_array.copy(), detections)
        annotated_image = Image.fromarray(annotated_array)
        annotated_filename = f"annotated_{filename}"
        annotated_path = os.path.join(app.config["UPLOAD_FOLDER"], annotated_filename)
        annotated_image.save(annotated_path, quality=95)

        tips = {
            "benchpress": "Feet planted, slight arch, shoulder blades retracted; control bar path.",
            "deadlift": "Hinge at hips, bar close to shins, lats tight; push the floor, don't jerk.",
            "squat": "Keep knees tracking over toes; brace your core; maintain neutral spine.",
            "leg_ext": "Control the movement, don't swing; focus on squeezing the quadriceps.",
            "pushup": "Keep body straight, engage core; lower chest to floor with control.",
            "shoulder_press": "Keep core tight, don't arch back excessively; press straight up."
        }
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # tips come back in a stable order instead of arbitrary set order.
        detected_exercises = list(dict.fromkeys(d['label'] for d in detections))
        exercise_tips = [tips.get(ex, "") for ex in detected_exercises]
        return jsonify({
            "ok": True,
            "original_image": url_for("uploaded_file", filename=filename),
            "annotated_image": url_for("uploaded_file", filename=annotated_filename),
            "detections": detections,
            "tips": exercise_tips
        })
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"ok": False, "error": str(e)}), 500
@app.route("/upload_video", methods=["POST"])
def upload_video():
    """Store an uploaded video and return the id used to stream it.

    Expects a multipart form field named ``video``. The stored name is the
    sanitized original name prefixed with a microsecond timestamp.
    """
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500

    uploads = request.files
    if "video" not in uploads:
        return jsonify({"ok": False, "error": "No video file"}), 400

    upload = uploads["video"]
    original_name = upload.filename
    if not (original_name and allowed_video(original_name)):
        return jsonify({"ok": False, "error": "Invalid video"}), 400

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    video_id = f"{timestamp}_{secure_filename(original_name)}"
    upload.save(os.path.join(app.config["VIDEO_FOLDER"], video_id))
    return jsonify({"ok": True, "video_id": video_id})
@app.route("/video_feed/<video_id>")
def video_feed(video_id):
    """Stream an uploaded video as annotated multipart JPEG frames.

    Args:
        video_id: File name returned by the upload endpoint.

    Returns:
        A streaming multipart response, or a JSON error: 500 when the model
        is not loaded, 400 when ``video_id`` is not a plain sanitized file
        name, 404 when no such video is stored.
    """
    if model is None:
        return jsonify({"ok": False, "error": "Model not loaded"}), 500
    # Ids issued by the upload endpoint are already sanitized, so anything
    # that secure_filename() would alter (e.g. "..") is not a valid id.
    if secure_filename(video_id) != video_id:
        return jsonify({"ok": False, "error": "Invalid video id"}), 400
    video_path = os.path.join(app.config["VIDEO_FOLDER"], video_id)
    if not os.path.isfile(video_path):
        return jsonify({"ok": False, "error": "Video not found"}), 404

    def generate():
        global last_frame_cache
        # Start every stream without a stale cached frame.
        last_frame_cache = None
        cap = cv2.VideoCapture(video_path)
        frame_count = 0
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY]
        try:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break
                annotated_frame, detections = process_frame_optimized(frame, frame_count)
                _, buffer = cv2.imencode('.jpg', cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR), encode_param)
                frame_bytes = buffer.tobytes()
                frame_count += 1
                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
                # Pace playback at roughly TARGET_FPS.
                time.sleep(1.0 / TARGET_FPS)
        finally:
            # A client disconnect closes the generator at the yield; the
            # finally block guarantees the capture is released and the cache
            # cleared in that case too.
            cap.release()
            last_frame_cache = None

    return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')
# Load the model once at import time so it is ready before the first request,
# whichever way the module is started.
_banner = "=" * 60
print("\n" + _banner)
print("FLASK APP STARTING")
print(_banner)

model_loaded = load_model()
if not model_loaded:
    print("\n✗ Model failed to load")
else:
    print("\n✓ App ready for inference")
    print(f"Device: {device}")
print(_banner + "\n")
if __name__ == "__main__":
    # IMPORTANT: Hugging Face Spaces requires port 7860, used here as the
    # default when the PORT environment variable is not set.
    port = int(os.getenv("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False, threaded=True)