# Uploaded as backend/main.py by dineth554 via huggingface_hub
# (commit b22fdb6, verified).
import os
import sys
import time
import logging
import shutil
import tempfile
from pathlib import Path
from typing import Optional, List
import uuid
from fastapi import FastAPI, UploadFile, File, Form, BackgroundTasks, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field
import uvicorn
# Put the directory above this file's directory at the front of sys.path so
# top-level project modules (presumably `inference`, imported lazily by
# get_generator) can be resolved.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)

_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("LegionAPI")

app = FastAPI(
    title="LEGION Video Generation API",
    description="High-quality video generation with text-to-video and image-to-video capabilities",
    version="1.0.0",
)

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): browsers reject a literal "*" origin on credentialed requests,
# and depending on the Starlette version the middleware may echo the caller's
# Origin instead, effectively trusting any site. Confirm this is intended
# before exposing the API publicly.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Shared generator instance; stays None until get_generator() first runs.
_generator = None


def get_generator():
    """Return the process-wide LegionVideoGenerator, creating it on first use.

    The ``inference`` import and the constructor run only on the first call.
    Later calls return the cached instance. There is no locking: the callers
    in this module invoke it synchronously from async routes.
    """
    global _generator
    if _generator is not None:
        return _generator

    from inference import LegionVideoGenerator
    instance = _generator = LegionVideoGenerator()
    logger.info(f"Generator initialized: device={instance.device}, mock={instance.mock_mode}")
    return instance
# --- Request/Response Models ---

# Default negative prompt applied to text-to-video requests.
_T2V_DEFAULT_NEGATIVE_PROMPT = (
    "warped, distorted, flickering, jittery, low quality, blurry, "
    "artifacts, ugly, deformed, bad anatomy, bad proportions"
)


class TextToVideoRequest(BaseModel):
    # JSON body of POST /api/generate/text. Bounds are enforced by pydantic.
    # (Comments rather than a docstring: a class docstring would be published
    # as the schema description.)
    prompt: str = Field(..., description="Text description of the video")
    negative_prompt: str = Field(
        _T2V_DEFAULT_NEGATIVE_PROMPT,
        description="Things to avoid in the video",
    )
    num_frames: int = Field(49, ge=1, le=129, description="Number of frames")
    width: int = Field(480, ge=128, le=720, description="Video width")
    height: int = Field(480, ge=128, le=720, description="Video height")
    num_inference_steps: int = Field(50, ge=1, le=100, description="Inference steps")
    guidance_scale: float = Field(6.0, ge=1.0, le=20.0, description="CFG scale")
    watermark_strength: float = Field(0.3, ge=0.0, le=1.0, description="QWatermark opacity")
    # None leaves seeding to the generator.
    seed: Optional[int] = Field(None, description="Random seed")
# Default negative prompt for image-to-video parameters.
_I2V_DEFAULT_NEGATIVE_PROMPT = (
    "warped, distorted, flickering, jittery, low quality, blurry, artifacts"
)


class ImageToVideoRequest(BaseModel):
    # Image-to-video generation parameters. The visible /api/generate/image
    # route takes the same values as multipart form fields instead of this model.
    prompt: str = Field(..., description="Text description of motion/action")
    negative_prompt: str = Field(_I2V_DEFAULT_NEGATIVE_PROMPT, description="Things to avoid")
    num_frames: int = Field(49, ge=1, le=129)
    width: int = Field(480, ge=128, le=720)
    height: int = Field(480, ge=128, le=720)
    num_inference_steps: int = Field(50, ge=1, le=100)
    guidance_scale: float = Field(6.0, ge=1.0, le=20.0)
    watermark_strength: float = Field(0.3, ge=0.0, le=1.0)
    seed: Optional[int] = Field(None)
class StatusResponse(BaseModel):
    # Response model of GET /api/status (declared via response_model there).
    status: str
    # Human-readable model name reported by the status endpoint.
    model: str
    # True when the generator's device is "cuda".
    gpu: bool
    # Mirrors the generator's mock_mode flag.
    mock_mode: bool
    # Raw device string reported by the generator.
    device: str
    # API version string.
    version: str
# --- Endpoints ---


@app.get("/api/status", response_model=StatusResponse)
async def api_status():
    # Health/status probe. The first call to any generator-backed route also
    # triggers the generator's lazy initialization.
    generator = get_generator()
    device = generator.device
    return StatusResponse(
        status="ok",
        model="LEGION Video Gen v1.0",
        gpu=device == "cuda",
        mock_mode=generator.mock_mode,
        device=device,
        version="1.0.0",
    )
@app.get("/api/models")
async def api_models():
    # The model list is static. The get_generator() call only forces lazy
    # initialization (its result is not used); presumably intentional, confirm.
    get_generator()
    catalog = [
        {"id": "legion-t2v", "name": "LEGION Text-to-Video"},
        {"id": "legion-i2v", "name": "LEGION Image-to-Video"},
    ]
    return {"models": catalog}
@app.post("/api/generate/text")
async def generate_text(
    request: TextToVideoRequest,
    background_tasks: BackgroundTasks,
):
    """
    Generate a video from a text prompt.
    Returns the video file as an MP4 download.
    """
    # Developer notes (kept out of the docstring, which FastAPI publishes as
    # the OpenAPI description):
    # - Failures surface as HTTP 500. Unexpected errors use the exception text
    #   as the detail and are logged with their traceback.
    # - NOTE(review): generate_from_text() runs synchronously inside this async
    #   route, so the event loop is blocked for the whole generation and other
    #   requests (including /api/status) wait behind it. Offloading it to a
    #   worker thread would also permit concurrent generations on one model;
    #   confirm the generator tolerates that (or add a lock) before changing.
    # - NOTE(review): background_tasks is accepted but currently unused.
    logger.info(
        "API T2V: prompt='%s...', frames=%s", request.prompt[:50], request.num_frames
    )
    try:
        output_path = get_generator().generate_from_text(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_frames=request.num_frames,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            watermark_strength=request.watermark_strength,
            seed=request.seed,
        )
        # Check for a falsy result first: os.path.exists(None) raises TypeError.
        if not output_path or not os.path.exists(output_path):
            raise HTTPException(status_code=500, detail="Video generation failed - no output file")
        filename = os.path.basename(output_path)
        # FileResponse builds the "attachment" Content-Disposition header from
        # `filename` itself (encoding non-ASCII names safely), so no manual header.
        return FileResponse(path=output_path, media_type="video/mp4", filename=filename)
    except HTTPException:
        # Already carries the intended status/detail; do not re-wrap it below.
        raise
    except Exception as e:
        logger.exception("T2V generation failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Name used for an upload whose client filename has no usable extension.
_DEFAULT_UPLOAD_NAME = "input_image.jpg"


def _safe_upload_name(filename: Optional[str]) -> str:
    """Return a server-chosen file name for an uploaded image.

    Only the extension of the client-supplied name is kept, and only when it
    is plain ASCII alphanumerics. Directory parts, '..' and absolute paths are
    discarded, so the upload cannot be written outside its temp directory.
    """
    suffix = Path(filename or "").suffix
    ext = suffix[1:]
    if ext and ext.isascii() and ext.isalnum():
        return "input_image" + suffix
    return _DEFAULT_UPLOAD_NAME


@app.post("/api/generate/image")
async def generate_image(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    negative_prompt: str = Form("warped, distorted, flickering, jittery, low quality, blurry, artifacts"),
    # Bounds mirror ImageToVideoRequest / the text-to-video endpoint.
    num_frames: int = Form(49, ge=1, le=129),
    width: int = Form(480, ge=128, le=720),
    height: int = Form(480, ge=128, le=720),
    num_inference_steps: int = Form(50, ge=1, le=100),
    guidance_scale: float = Form(6.0, ge=1.0, le=20.0),
    watermark_strength: float = Form(0.3, ge=0.0, le=1.0),
    seed: Optional[int] = Form(None),
):
    """
    Generate a video from an uploaded image + text prompt.
    """
    # Developer notes (kept out of the OpenAPI-published docstring):
    # - The upload is written to a fresh temp directory under a server-chosen
    #   name (never the raw client filename). The directory is removed in
    #   `finally`, before the response body is streamed.
    #   NOTE(review): that is only safe if the generator writes its output
    #   outside this temp directory — confirm.
    # - Failures surface as HTTP 500; unexpected errors are logged with traceback.
    # - NOTE(review): generation blocks the event loop (see generate_text).
    logger.info("API I2V: file='%s', prompt='%s...'", file.filename, prompt[:50])
    temp_dir = tempfile.mkdtemp()
    try:
        image_path = os.path.join(temp_dir, _safe_upload_name(file.filename))
        content = await file.read()
        with open(image_path, "wb") as f:
            f.write(content)
        output_path = get_generator().generate_from_image(
            image_path=image_path,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_frames=num_frames,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            watermark_strength=watermark_strength,
            seed=seed,
        )
        # Check for a falsy result first: os.path.exists(None) raises TypeError.
        if not output_path or not os.path.exists(output_path):
            raise HTTPException(status_code=500, detail="Video generation failed - no output file")
        filename = os.path.basename(output_path)
        # FileResponse sets the "attachment" Content-Disposition from `filename`.
        return FileResponse(path=output_path, media_type="video/mp4", filename=filename)
    except HTTPException:
        # Already carries the intended status/detail; do not re-wrap it below.
        raise
    except Exception as e:
        logger.exception("I2V generation failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
@app.get("/")
async def root():
    # Service banner: name, version, docs location and a route map.
    routes = {
        "status": "GET /api/status",
        "models": "GET /api/models",
        "text_to_video": "POST /api/generate/text",
        "image_to_video": "POST /api/generate/image",
    }
    return {
        "service": "LEGION Video Generation API",
        "version": "1.0.0",
        "docs": "/docs",
        "endpoints": routes,
    }
if __name__ == "__main__":
    # Listen on every interface. The port comes from BACKEND_PORT (default 8081).
    listen_port = int(os.getenv("BACKEND_PORT", "8081"))
    logger.info(f"Starting LEGION API server on port {listen_port}")
    uvicorn.run(app, log_level="info", host="0.0.0.0", port=listen_port)