Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File | |
| import cv2 | |
| import torch | |
| import pandas as pd | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from tqdm import tqdm | |
| import json | |
| import shutil | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse | |
# ASGI application object; middleware is attached to it right below.
app = FastAPI()

# CORS policy: accept cross-origin requests from the local front-end only
# (replace the origin with the URL of your Vue.js app), with credentials,
# for every HTTP method and every request header.
_cors_options = {
    "allow_origins": ["http://localhost:8080"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Load the image processor and the fine-tuned model.
# The checkpoint directory is resolved relative to the current working
# directory. A forward slash is used on purpose: "./dir" is understood on
# Windows, Linux and macOS, whereas ".\dir" is only a directory path on Windows.
local_model_path = "./vit-finetuned-ucf101"
# The processor comes from the base ViT checkpoint on the Hub, not from the
# local fine-tuned directory.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained(local_model_path)
# Alternative: load the fine-tuned weights from the Hub instead of local disk.
# model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101")
# Inference only: switch the module to evaluation mode.
model.eval()
# Classify a single image
def classifier_image(image):
    """Return the label predicted by the module-level ``model`` for one frame.

    Args:
        image: A frame in OpenCV's BGR channel order (it is converted with
            ``cv2.COLOR_BGR2RGB`` before being handed to the processor).

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.
    """
    rgb_pixels = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    batch = processor(images=Image.fromarray(rgb_pixels), return_tensors="pt")
    # Gradients are never needed for prediction.
    with torch.no_grad():
        scores = model(**batch).logits
    best_idx = scores.argmax(-1).item()
    return model.config.id2label[best_idx]
# Process a video and identify its "Surfing" sequences
def _group_surfing_samples(samples, video_end):
    """Merge classified samples into ``(start, end)`` "Surfing" intervals.

    Args:
        samples: Chronological list of ``(timestamp, label)`` pairs.
        video_end: Timestamp used to close a sequence that is still open after
            the last sample.

    Returns:
        A list of ``(start, end)`` tuples. A sequence starts at the first
        "Surfing" sample and ends at the first non-"Surfing" sample that is not
        immediately followed by a "Surfing" sample.
    """
    sequences = []
    start_timestamp = None
    for position, (timestamp, classe) in enumerate(samples):
        if classe == "Surfing":
            if start_timestamp is None:
                start_timestamp = timestamp
            continue
        if start_timestamp is None:
            continue
        # A lone non-"Surfing" sample between two "Surfing" samples is treated
        # as a one-off misclassification and does not end the sequence.
        has_next = position + 1 < len(samples)
        next_classe = samples[position + 1][1] if has_next else None
        if next_classe != "Surfing":
            sequences.append((start_timestamp, timestamp))
            start_timestamp = None
    if start_timestamp is not None:
        # Still surfing when the video ran out of frames.
        sequences.append((start_timestamp, video_end))
    return sequences


def identifier_sequences_surfing(video_path, intervalle=0.5):
    """Detect the time ranges of a video that are classified as "Surfing".

    One frame every ``intervalle`` seconds is classified with
    ``classifier_image``; consecutive "Surfing" samples are then merged into
    sequences (see ``_group_surfing_samples`` for the tolerance rule).

    Args:
        video_path: Path of the video file, passed to ``cv2.VideoCapture``.
        intervalle: Sampling period in seconds. Values that amount to less
            than one frame fall back to sampling every frame.

    Returns:
        A ``pandas.DataFrame`` with columns ``"Début"`` and ``"Fin"`` (seconds,
        rounded to 2 decimals), one row per sequence. It is empty when no frame
        can be read or nothing is classified as "Surfing".

    Raises:
        ValueError: If frames can be decoded but the container reports a
            non-positive frame rate, so timestamps cannot be computed.
    """
    samples = []
    frame_index = 0
    cap = cv2.VideoCapture(video_path)
    try:
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        success, frame = cap.read()
        if success and not frame_rate > 0:
            raise ValueError(
                f"Cannot timestamp frames of {video_path!r}: "
                f"reported frame rate is {frame_rate}"
            )
        # Never let the sampling step reach 0 (it is used as a modulo divisor).
        frame_interval = max(1, int(frame_rate * intervalle)) if success else 1
        with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
            while success:
                # frame_index always equals the number of frames decoded so
                # far, so timestamps match the real position in the video.
                if frame_index % frame_interval == 0:
                    timestamp = round(frame_index / frame_rate, 2)
                    samples.append((timestamp, classifier_image(frame)))
                success, frame = cap.read()
                frame_index += 1
                pbar.update(1)
    finally:
        # Release the capture even if decoding or classification fails.
        cap.release()

    video_end = round(frame_index / frame_rate, 2) if frame_index else None
    sequences_surfing = _group_surfing_samples(samples, video_end)
    return pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
# Convert the detected sequences into a JSON-ready structure
def convertir_sequences_en_json(dataframe):
    """Wrap the sequences of ``dataframe`` into a single "Surfing" event.

    Args:
        dataframe: Table with ``"Début"`` and ``"Fin"`` columns, one row per
            sequence; the row index is used to number the blocks.

    Returns:
        A one-element list ``[{"event": "Surfing", "blocks": [...]}]`` where
        each block has an ``id`` (``"Surfing<index + 1>"``) and ``start`` /
        ``end`` values rounded to 2 decimals.
    """
    surfing_blocks = [
        {
            "id": f"Surfing{idx + 1}",
            "start": round(row["Début"], 2),
            "end": round(row["Fin"], 2),
        }
        for idx, row in dataframe.iterrows()
    ]
    return [{"event": "Surfing", "blocks": surfing_blocks}]
async def analyze_video(file: UploadFile = File(...)):
    """Store an uploaded video, detect its "Surfing" sequences, return them.

    The upload is copied to a uniquely named temporary ``.mp4`` file so that
    simultaneous requests cannot overwrite each other's video, and that file
    is deleted once the analysis is over (successful or not).

    Args:
        file: The uploaded video (multipart form field).

    Returns:
        The structure built by ``convertir_sequences_en_json`` for the
        sequences found with a 1-second sampling interval.

    NOTE(review): no route decorator is visible on this handler in this view of
    the file — confirm it is registered on ``app``.
    NOTE(review): the analysis is synchronous and runs inside this coroutine,
    so the event loop is blocked for its whole duration.
    """
    import contextlib
    import os
    import tempfile

    # delete=False: the file must outlive this `with` so OpenCV can reopen it
    # by name (required on Windows, harmless elsewhere).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as buffer:
        shutil.copyfileobj(file.file, buffer)
        video_path = buffer.name
    try:
        dataframe_sequences = identifier_sequences_surfing(video_path, intervalle=1)
    finally:
        with contextlib.suppress(FileNotFoundError):
            os.remove(video_path)
    return convertir_sequences_en_json(dataframe_sequences)
async def index():
    """Return the static HTML landing page that points to the API docs.

    NOTE(review): no route decorator is visible on this handler in this view of
    the file — confirm it is registered on ``app`` (the page text refers to
    the `/` endpoint).
    """
    landing_page = """
<html>
<body>
<h1>Hello world!</h1>
<p>This `/` is the most simple and default endpoint.</p>
<p>If you want to learn more, check out the documentation of the API at
<a href='/docs'>/docs</a> or
<a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
</p>
</body>
</html>
"""
    return landing_page
# Run the application with uvicorn (command line):
#   uvicorn main:app --reload
# Interactive API documentation: http://localhost:8000/docs#/
# Launch reachable from other hosts, single worker:
#   uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1