File size: 13,883 Bytes
54ed165 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 | from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from gradio_client import Client
import os
import logging
from dotenv import load_dotenv
from datetime import datetime
import base64
from io import BytesIO
from PIL import Image
import tempfile
from models_config import (
VIDEO_MODELS,
CAMERA_MOVEMENTS,
VISUAL_EFFECTS,
VIDEO_STYLES,
EXAMPLE_PROMPTS,
get_model_info,
get_available_models,
build_enhanced_prompt
)
# Load environment variables before any configuration value is read.
load_dotenv()

# Send log records to both app.log and the console, using one shared format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('app.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Flask application object; CORS is applied with flask_cors defaults.
app = Flask(__name__)
CORS(app)

# Runtime configuration, each value overridable through the environment.
FLASK_PORT = int(os.environ.get('FLASK_PORT', 5000))
FLASK_DEBUG = os.environ.get('FLASK_DEBUG', 'False').lower() == 'true'
DEFAULT_MODEL = os.environ.get('DEFAULT_MODEL', 'cogvideox-5b')

# Prompt length bounds (in characters, after stripping) used by validate_prompt.
MAX_PROMPT_LENGTH = 1000
MIN_PROMPT_LENGTH = 3

# Cache of Gradio clients keyed by model id, filled by get_or_create_client.
model_clients = {}
def get_or_create_client(model_id):
    """Return the cached Gradio client for ``model_id``, creating it on first use.

    The Space URL is read from ``get_model_info(model_id)['space_url']``.
    Any exception raised while looking up the model or constructing the
    client is logged and swallowed.

    Returns:
        The client stored in ``model_clients``, or ``None`` when it could not
        be created.
    """
    # Fast path: a client for this model has already been built.
    if model_id in model_clients:
        return model_clients.get(model_id)

    try:
        space_url = get_model_info(model_id)['space_url']
        logger.info(f"Initializing client for {model_id}: {space_url}")
        model_clients[model_id] = Client(space_url, verbose=False)
        logger.info(f"Successfully connected to {model_id}")
    except Exception as e:
        logger.error(f"Failed to initialize client for {model_id}: {str(e)}")
        # Troubleshooting hints, emitted one log record per line.
        for hint in (
            "This might be because:",
            " 1. The Hugging Face Space is not available or sleeping",
            " 2. The Space URL has changed",
            " 3. The Space requires authentication",
            " 4. Network connectivity issues",
        ):
            logger.error(hint)
        return None

    return model_clients.get(model_id)
def validate_prompt(prompt):
    """Check that ``prompt`` is a string of acceptable length.

    Returns:
        ``(True, stripped_prompt)`` when the prompt is valid, otherwise
        ``(False, error_message)``.
    """
    if not (prompt and isinstance(prompt, str)):
        return False, "Prompt must be a non-empty string"

    # Length limits apply to the prompt without surrounding whitespace.
    cleaned = prompt.strip()
    size = len(cleaned)

    if size < MIN_PROMPT_LENGTH:
        return False, f"Prompt must be at least {MIN_PROMPT_LENGTH} characters long"
    if size > MAX_PROMPT_LENGTH:
        return False, f"Prompt must not exceed {MAX_PROMPT_LENGTH} characters"

    return True, cleaned
def decode_base64_image(base64_string):
    """Turn a base64 string (optionally a data URL) into a PIL image.

    Returns:
        The object returned by ``Image.open``, or ``None`` if anything fails
        along the way (the failure is logged).
    """
    try:
        # For input such as "data:image/png;base64,<payload>", keep the
        # segment that follows the first comma.
        if ',' in base64_string:
            payload = base64_string.split(',')[1]
        else:
            payload = base64_string

        raw_bytes = base64.b64decode(payload)
        return Image.open(BytesIO(raw_bytes))
    except Exception as e:
        logger.error(f"Failed to decode image: {str(e)}")
        return None
@app.route('/health', methods=['GET'])
def health_check():
    """Report service status, current time, known model ids and the default model."""
    payload = {
        'status': 'healthy',
        'timestamp': datetime.now().isoformat(),
        'available_models': list(VIDEO_MODELS.keys()),
        'default_model': DEFAULT_MODEL,
    }
    return jsonify(payload)
@app.route('/models', methods=['GET'])
def list_models():
    """Return the model catalogue together with the prompt-building options."""
    catalogue = dict(
        models=get_available_models(),
        camera_movements=CAMERA_MOVEMENTS,
        visual_effects=VISUAL_EFFECTS,
        video_styles=VIDEO_STYLES,
        example_prompts=EXAMPLE_PROMPTS,
    )
    return jsonify(catalogue)
@app.route('/test-video', methods=['POST'])
def test_video():
    """Return a fixed sample video URL so the UI can be tested without a model.

    The optional ``prompt`` field of the JSON body is echoed back in the
    response. A missing, malformed or non-object body is treated as empty,
    so this endpoint always answers with the demo payload.
    """
    # silent=True makes get_json return None instead of raising on a wrong
    # content type or unparsable body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    prompt = data.get('prompt', 'Test prompt')
    # str() keeps the log line safe when the client sends a non-string prompt.
    logger.info(f"Test mode: Simulating video generation for: {str(prompt)[:100]}")

    # Big Buck Bunny is an open-source test video.
    return jsonify({
        'video_url': 'https://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4',
        'prompt': prompt,
        'enhanced_prompt': prompt,
        'model': 'test-mode',
        'model_name': 'Test Mode (Demo Video)',
        'timestamp': datetime.now().isoformat(),
        'note': 'This is a demo video. Connect to Hugging Face Spaces for real generation.'
    })
def _predict_text_to_video(client, model_id, model_info, enhanced_prompt):
    """Call the model's Gradio endpoint with the argument layout that model expects."""
    api_name = model_info['api_name']

    if model_id in ('cogvideox-5b', 'cogvideox-2b'):
        logger.info(f"Calling CogVideoX {model_id} with prompt: {enhanced_prompt[:100]}")
        # The seed is pinned to 0, so every call passes the same seed value.
        return client.predict(prompt=enhanced_prompt, seed=0, api_name=api_name)

    if model_id == 'hunyuan-video':
        logger.info(f"Calling HunyuanVideo with prompt: {enhanced_prompt[:100]}")
    else:
        logger.info(f"Calling {model_id} with generic approach")
    # HunyuanVideo and all remaining models take the prompt as the only
    # positional argument.
    return client.predict(enhanced_prompt, api_name=api_name)


@app.route('/generate-video', methods=['POST'])
def generate_video():
    """Generate a video from a text prompt with optional camera/effect/style hints.

    Expects a JSON object with ``prompt`` and optionally ``model``,
    ``camera_movement``, ``visual_effect`` and ``style``.

    Returns:
        200 with ``video_url``, ``prompt``, ``enhanced_prompt``, ``model``,
        ``model_name`` and ``timestamp`` (plus ``note`` in demo mode);
        400 for a bad body, prompt or model; 503 when no client could be
        created; 500/504 for generation failures and unexpected errors.
    """
    try:
        # get_json(silent=True) returns None for a wrong content type or an
        # unparsable body, so those requests get the 400 below instead of
        # raising into the generic 500 handler at the bottom.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({'error': 'Request must be JSON'}), 400

        model_id = data.get('model', DEFAULT_MODEL)

        # Advanced options (Hailuo-inspired)
        camera_movement = data.get('camera_movement', '')
        visual_effect = data.get('visual_effect', '')
        style = data.get('style', '')

        # validate_prompt rejects non-strings and strips whitespace itself,
        # so the raw value is passed straight through.
        is_valid, outcome = validate_prompt(data.get('prompt', ''))
        if not is_valid:
            logger.warning(f"Invalid prompt: {outcome}")
            return jsonify({'error': outcome}), 400
        base_prompt = outcome

        # The isinstance check keeps unhashable values (lists, objects) from
        # raising inside the membership test.
        if not isinstance(model_id, str) or model_id not in VIDEO_MODELS:
            return jsonify({'error': f'Invalid model: {model_id}'}), 400

        model_info = get_model_info(model_id)

        if model_info['type'] != 'text-to-video':
            return jsonify({'error': f'Model {model_id} does not support text-to-video generation'}), 400

        # Build enhanced prompt with camera movements and effects
        enhanced_prompt = build_enhanced_prompt(base_prompt, camera_movement, visual_effect, style)

        logger.info(f"Generating video with {model_id}")
        logger.info(f"Base prompt: {base_prompt[:100]}...")
        logger.info(f"Enhanced prompt: {enhanced_prompt[:150]}...")

        # Demo mode short-circuits before any remote call is made.
        if model_id == 'demo':
            logger.info("Demo mode activated - returning sample video")
            return jsonify({
                'video_url': 'https://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4',
                'prompt': base_prompt,
                'enhanced_prompt': enhanced_prompt,
                'model': model_id,
                'model_name': model_info['name'],
                'timestamp': datetime.now().isoformat(),
                'note': 'Demo mode: This is a sample video. Select a real model for AI generation.'
            })

        client = get_or_create_client(model_id)
        if client is None:
            return jsonify({'error': 'Failed to connect to video generation service. Try using "Demo Mode" model to test the UI.'}), 503

        try:
            prediction = _predict_text_to_video(client, model_id, model_info, enhanced_prompt)
        except Exception as e:
            logger.error(f"Model API call failed: {str(e)}")
            logger.error("This usually means:")
            logger.error(" 1. The Hugging Face Space is sleeping or unavailable")
            logger.error(" 2. The API has changed")
            logger.error(" 3. Try using 'Demo Mode' to test the UI")
            return jsonify({'error': f'Video generation failed: {str(e)}. Try Demo Mode or a different model.'}), 500

        # Multi-output endpoints may hand back a list or tuple; the video is
        # taken from the first slot. An empty sequence counts as "no output".
        if isinstance(prediction, (list, tuple)):
            video_path = prediction[0] if prediction else None
        else:
            video_path = prediction

        if not video_path:
            logger.error("No video path returned from API")
            return jsonify({'error': 'Failed to generate video. No output received.'}), 500

        logger.info(f"Video generated successfully: {video_path}")

        # NOTE(review): video_path is returned exactly as the client produced
        # it; if that is a server-local file path the browser cannot fetch
        # it directly - confirm against the gradio_client version in use.
        return jsonify({
            'video_url': video_path,
            'prompt': base_prompt,
            'enhanced_prompt': enhanced_prompt,
            'model': model_id,
            'model_name': model_info['name'],
            'timestamp': datetime.now().isoformat()
        })

    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return jsonify({'error': f'Invalid input: {str(e)}'}), 400
    except ConnectionError as e:
        logger.error(f"Connection error: {str(e)}")
        return jsonify({'error': 'Failed to connect to video generation service. Please try again later.'}), 503
    except TimeoutError as e:
        logger.error(f"Timeout error: {str(e)}")
        return jsonify({'error': 'Request timed out. The service may be busy. Please try again.'}), 504
    except Exception as e:
        logger.error(f"Unexpected error in generate_video: {str(e)}", exc_info=True)
        return jsonify({'error': 'An unexpected error occurred. Please try again later.'}), 500
@app.route('/generate-video-from-image', methods=['POST'])
def generate_video_from_image():
    """Generate a video from a base64 image plus an optional text prompt.

    Expects a JSON object with ``image`` (base64 or data URL) and optionally
    ``prompt`` and ``model`` (defaults to ``'stable-video-diffusion'``).

    Returns:
        200 with ``video_url``, ``prompt``, ``model``, ``model_name`` and
        ``timestamp``; 400 for a bad body, prompt, model or image; 503 when
        no client could be created; 500 for any other failure.
    """
    # Tracked outside the try so the finally clause can always remove the file.
    temp_image_path = None
    try:
        # get_json(silent=True) returns None for a wrong content type or an
        # unparsable body, so those requests get a 400 rather than a 500.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({'error': 'Request must be JSON'}), 400

        raw_prompt = data.get('prompt', '')
        if raw_prompt is None:
            raw_prompt = ''
        if not isinstance(raw_prompt, str):
            return jsonify({'error': 'Prompt must be a string'}), 400
        prompt = raw_prompt.strip()

        image_data = data.get('image', '')
        model_id = data.get('model', 'stable-video-diffusion')

        # Same model check as the text-to-video endpoint, so an unknown id is
        # a 400 instead of a failure inside the model lookup.
        if not isinstance(model_id, str) or model_id not in VIDEO_MODELS:
            return jsonify({'error': f'Invalid model: {model_id}'}), 400

        model_info = get_model_info(model_id)
        if model_info['type'] != 'image-to-video':
            return jsonify({'error': f'Model {model_id} does not support image-to-video generation'}), 400

        image = decode_base64_image(image_data)
        if image is None:
            return jsonify({'error': 'Invalid image data'}), 400

        # delete=False because the file must outlive this with-block: its
        # path is handed to the client below. The path is recorded before
        # saving so a failed save still gets cleaned up, and the image is
        # written through the open handle rather than reopening the path.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
            temp_image_path = tmp_file.name
            image.save(tmp_file, format='PNG')

        logger.info(f"Generating video from image with {model_id}")
        logger.info(f"Prompt: {prompt[:100]}...")

        client = get_or_create_client(model_id)
        if client is None:
            return jsonify({'error': 'Failed to connect to video generation service'}), 503

        api_name = model_info['api_name']
        if model_id == 'stable-video-diffusion':
            # This model is called with the image only.
            prediction = client.predict(temp_image_path, api_name=api_name)
        else:
            # animatediff and every other model take image + prompt.
            prediction = client.predict(temp_image_path, prompt, api_name=api_name)

        # Multi-output endpoints may hand back a list or tuple; the video is
        # taken from the first slot. An empty sequence counts as "no output".
        if isinstance(prediction, (list, tuple)):
            video_path = prediction[0] if prediction else None
        else:
            video_path = prediction

        if not video_path:
            return jsonify({'error': 'Failed to generate video from image'}), 500

        logger.info("Video generated from image successfully")

        return jsonify({
            'video_url': video_path,
            'prompt': prompt,
            'model': model_id,
            'model_name': model_info['name'],
            'timestamp': datetime.now().isoformat()
        })

    except Exception as e:
        logger.error(f"Error in generate_video_from_image: {str(e)}", exc_info=True)
        return jsonify({'error': f'An error occurred: {str(e)}'}), 500
    finally:
        # Runs on every exit path, including the early returns above. Cleanup
        # problems are logged instead of replacing the response.
        if temp_image_path is not None:
            try:
                os.unlink(temp_image_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temp image {temp_image_path}: {cleanup_error}")
@app.errorhandler(404)
def not_found(e):
    """Answer unknown routes with a JSON error body and status 404."""
    body = jsonify({'error': 'Endpoint not found'})
    return body, 404
@app.errorhandler(405)
def method_not_allowed(e):
    """Answer unsupported HTTP methods with a JSON error body and status 405."""
    body = jsonify({'error': 'Method not allowed'})
    return body, 405
@app.errorhandler(500)
def internal_error(e):
    """Log the error, then answer with a generic JSON body and status 500."""
    logger.error(f"Internal server error: {str(e)}")
    body = jsonify({'error': 'Internal server error'})
    return body, 500
if __name__ == '__main__':
    # Log the effective configuration, then start the server on all interfaces.
    logger.info(f"Starting Enhanced Flask server on port {FLASK_PORT} (debug={FLASK_DEBUG})")
    model_names = ', '.join(VIDEO_MODELS.keys())
    logger.info(f"Available models: {model_names}")
    logger.info(f"Default model: {DEFAULT_MODEL}")
    app.run(
        host='0.0.0.0',
        port=FLASK_PORT,
        debug=FLASK_DEBUG,
    )
|