# Source: vo / app.py, last commit d6225df "Update app.py" by fomext (verified), 5.84 kB.
import os
import subprocess
import sys
import threading
import time
import traceback
import uuid

from fastapi import FastAPI, UploadFile, File, Form, BackgroundTasks
import torch
import cv2
import whisper
from scenedetect import VideoManager, SceneManager
from scenedetect.detectors import ContentDetector
from ultralytics import YOLO
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
# ===============================
# App + dirs
# ===============================
app = FastAPI()

# Working directories, relative to the process's current directory:
# client uploads are written to UPLOAD_DIR, generated files to OUTPUT_DIR.
# Both are created at import time if they do not exist yet.
UPLOAD_DIR, OUTPUT_DIR = "uploads", "outputs"
for _dir in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_dir, exist_ok=True)
del _dir
# ===============================
# Device / dtype
# ===============================
# Use the GPU in half precision when CUDA is available; otherwise run on the
# CPU in full precision.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32
# ===============================
# Job store
# ===============================
# Process-local registry of /edit jobs: job id (hex string) -> mutable status
# dict that /status/{job_id} reports. Nothing in this file evicts entries, and
# the registry is lost when the process restarts.
jobs = dict()
# ===============================
# Load models
# ===============================
# All models are loaded once, at import time, and shared by every request.
whisper_model = whisper.load_model("base")
yolo = YOLO("yolov8n.pt")
# Load the SVD weights directly in DTYPE instead of materialising them in
# float32 and casting afterwards with .to(dtype=...). The resulting pipeline
# is the same, but on CUDA (float16) peak host memory during loading is
# roughly halved.
svd = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    torch_dtype=DTYPE,
)
svd.to(DEVICE)
# ===============================
# Utils
# ===============================
def print_bar(label: str, percent: float, width: int = 40):
    """Draw a one-line text progress bar on stdout, overwriting the current line.

    Args:
        label: Short tag printed in brackets before the percentage.
        percent: Completion in percent. It is printed as given, but the bar
            fill is clamped to [0, 100] so the bar is always ``width`` cells.
        width: Number of character cells inside the bar.
    """
    # Without the clamp, percent > 100 overflows the bar and percent < 0
    # produces more than ``width`` padding cells.
    clamped = min(max(percent, 0.0), 100.0)
    filled = int(width * clamped / 100)
    bar = "█" * filled + " " * (width - filled)
    # "\r" + end="" redraws in place; flush so the update is visible at once.
    print(f"\r[{label}] {percent:5.1f}% |{bar}|", end="", flush=True)
# ===============================
# Health
# ===============================
@app.get("/")
def root():
    """Liveness check: always reports that the service is up."""
    health = {"status": "ok"}
    return health
# ===============================
# Endpoints
# ===============================
@app.post("/captions")
async def captions(file: UploadFile = File(...)):
    """Transcribe an uploaded audio/video file with the Whisper model.

    Returns:
        ``{"segments": ..., "language": ...}`` taken from Whisper's
        ``transcribe`` result.
    """
    # file.filename is client-controlled. Joining it into a path lets names
    # such as "../app.py" escape UPLOAD_DIR, and lets two uploads with the
    # same name overwrite each other mid-request. Store the upload under a
    # random name and keep only a plain alphanumeric extension.
    ext = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not ext[1:].isalnum():
        ext = ""
    path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}{ext}")
    with open(path, "wb") as f:
        f.write(await file.read())
    try:
        # NOTE(review): this blocking call runs inside an async endpoint, so
        # the event loop stalls until transcription finishes. That also keeps
        # use of the shared model serialized; confirm thread-safety before
        # moving it to a worker thread.
        result = whisper_model.transcribe(path)
    finally:
        # The upload is only needed for this request; don't let it pile up.
        os.remove(path)
    return {
        "segments": result["segments"],
        "language": result["language"],
    }
@app.post("/scene-detect")
async def scene_detect(file: UploadFile = File(...)):
    """Split an uploaded video into scenes with PySceneDetect's ContentDetector.

    Returns:
        ``{"scenes": [{"start": s, "end": e}, ...]}`` with boundaries in seconds.
    """
    # Never use the client-supplied filename as a path (traversal and
    # collisions). Store under a random name with a plain extension only.
    ext = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not ext[1:].isalnum():
        ext = ""
    path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}{ext}")
    with open(path, "wb") as f:
        f.write(await file.read())
    try:
        # NOTE(review): VideoManager is deprecated in newer PySceneDetect
        # releases (open_video replaces it); kept because the installed
        # version is not known here.
        video_manager = VideoManager([path])
        try:
            scene_manager = SceneManager()
            scene_manager.add_detector(ContentDetector(threshold=27.0))
            video_manager.start()
            scene_manager.detect_scenes(frame_source=video_manager)
            scenes = scene_manager.get_scene_list()
        finally:
            # Release even if decoding/detection raises, so the video handle
            # is not leaked.
            video_manager.release()
    finally:
        os.remove(path)
    return {
        "scenes": [
            {"start": start.get_seconds(), "end": end.get_seconds()}
            for start, end in scenes
        ]
    }
@app.post("/smart-crop")
async def smart_crop(file: UploadFile = File(...), aspect: str = Form("9:16")):
    """Detect a subject on the first frame of an uploaded video with YOLO.

    Returns:
        ``{"crop_box": [...], "aspect": aspect}``. The box is the first
        detection's ``xyxy`` coordinates. ``aspect`` is echoed back unchanged
        and is not used to compute the box. On failure, returns
        ``{"error": ...}``.
    """
    # Never use the client-supplied filename as a path (traversal and
    # collisions). Store under a random name with a plain extension only.
    ext = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not ext[1:].isalnum():
        ext = ""
    path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}{ext}")
    with open(path, "wb") as f:
        f.write(await file.read())
    try:
        cap = cv2.VideoCapture(path)
        try:
            ret, frame = cap.read()
        finally:
            cap.release()
    finally:
        # Only the first frame (now in memory) is needed.
        os.remove(path)
    if not ret:
        return {"error": "Failed to read video frame"}
    results = yolo(frame)
    boxes = results[0].boxes
    if boxes is None or len(boxes) == 0:
        return {"error": "No subject detected"}
    box = boxes.xyxy[0].cpu().numpy()
    return {"crop_box": box.tolist(), "aspect": aspect}
# ===============================
# Background job
# ===============================
# The single `svd` pipeline is shared by every job, and BackgroundTasks may run
# several jobs at the same time. Serialize pipeline calls so concurrent jobs
# cannot interleave scheduler state or compete for device memory.
_SVD_LOCK = threading.Lock()


def _extract_first_frame(video_path: str, frame_path: str) -> None:
    """Write the first frame of ``video_path`` to ``frame_path``, scaled to fit 512x512."""
    proc = subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", video_path,
            "-vf", "scale=512:512:force_original_aspect_ratio=decrease",
            "-frames:v", "1",
            "-update", "1",
            frame_path,
        ],
        capture_output=True,
        text=True,
        errors="replace",
    )
    if proc.returncode != 0:
        # Surface ffmpeg's own diagnostics rather than a bare exit status.
        tail = proc.stderr.strip()[-500:]
        raise RuntimeError(f"ffmpeg exited with status {proc.returncode}: {tail}")


def run_edit_job(job_id: str, video_path: str, frame_path: str):
    """Run one /edit job and record its progress in ``jobs[job_id]``.

    Steps: extract the first frame of the uploaded video, then run Stable
    Video Diffusion on it. The job dict is updated in place:
    ``stage`` goes "extracting_frame" -> "waiting_for_model" -> "diffusion"
    -> "completed". ``progress`` follows the real denoising steps. On success,
    ``status`` becomes "done" and ``frames`` holds the number of generated
    frames. Any exception sets ``status`` to "error" and stores its message
    under ``error``.

    NOTE(review): the generated frames are not saved anywhere yet, and the
    job's prompt is not passed to this function.
    """
    job = jobs[job_id]
    try:
        job.update({"stage": "extracting_frame", "progress": 0})
        _extract_first_frame(video_path, frame_path)
        with Image.open(frame_path) as frame:
            img = frame.convert("RGB")

        num_steps = 25

        def _on_step_end(pipe, step_index, timestep, callback_kwargs):
            # Invoked by the pipeline after each denoising step, so progress
            # reflects real work instead of a timer.
            percent = (step_index + 1) / num_steps * 100
            job["progress"] = round(percent, 1)
            print_bar("SVD", percent)
            return callback_kwargs  # the pipeline expects the kwargs back

        job["stage"] = "waiting_for_model"
        with _SVD_LOCK:
            job["stage"] = "diffusion"
            try:
                with torch.no_grad():
                    output = svd(
                        image=img,
                        num_frames=8,
                        num_inference_steps=num_steps,
                        decode_chunk_size=4,
                        callback_on_step_end=_on_step_end,
                    )
            finally:
                print()  # end the "\r"-redrawn progress-bar line

        job.update({
            "status": "done",
            "stage": "completed",
            # output.frames is batched (one frame list per generated video);
            # len(output.frames) is the batch size, not the frame count.
            "frames": len(output.frames[0]),
            "progress": 100,
        })
    except Exception as e:
        # Keep a server-side traceback; the job itself only stores str(e).
        traceback.print_exc()
        job["status"] = "error"
        job["error"] = str(e)
# ===============================
# Status
# ===============================
@app.get("/status/{job_id}")
def job_status(job_id: str):
    """Return the current state of job ``job_id``, or ``{"status": "not_found"}``."""
    job = jobs.get(job_id)
    if job is None:
        return {"status": "not_found"}
    # run_edit_job (scheduled via BackgroundTasks) may still be adding keys
    # such as "frames" or "error" to this dict while the response is being
    # serialized. Return a snapshot so serialization never iterates a dict
    # that changes size underneath it.
    return dict(job)
# ===============================
# Edit
# ===============================
@app.post("/edit")
async def edit_video(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    prompt: str = Form(...),
):
    """Store an uploaded video and queue a background edit job for it.

    The video is saved as ``<job_id>.mp4`` in UPLOAD_DIR, and the job is
    registered in ``jobs`` before ``run_edit_job`` is scheduled. ``prompt``
    is only recorded on the job; it is not passed to the worker. Returns the
    new job id, which can be polled via /status/{job_id}.
    """
    job_id = uuid.uuid4().hex
    video_path, frame_path = (
        os.path.join(UPLOAD_DIR, job_id + ".mp4"),
        os.path.join(OUTPUT_DIR, job_id + ".png"),
    )

    with open(video_path, "wb") as out:
        out.write(await file.read())

    jobs[job_id] = dict(
        status="running",
        stage="queued",
        progress=0,
        prompt_received_but_unused=prompt,
    )
    background_tasks.add_task(run_edit_job, job_id, video_path, frame_path)
    return dict(job_id=job_id, status="running")