|
|
""" |
|
|
Production API Endpoint |
|
|
Demonstrates complete Transformers + Safetensors integration with tier management |
|
|
""" |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks |
|
|
from fastapi.responses import JSONResponse |
|
|
from pydantic import BaseModel, Field |
|
|
from typing import List, Optional, Dict, Any |
|
|
import logging |
|
|
import uuid |
|
|
from datetime import datetime |
|
|
import asyncio |
|
|
|
|
|
|
|
|
from core.scene_planner import get_planner, ScenePlanner |
|
|
from models.image.sd_generator import get_generator, SafeStableDiffusionGenerator |
|
|
from config.model_tiers import get_tier_config, validate_model_weights_security |
|
|
|
|
|
|
|
|
# Configure the root logger once at import time so INFO-level messages
# emitted anywhere in this module are not filtered out.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger shared by the handlers and background task below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# The ASGI application object; every route below is registered on it.
# The keyword arguments are descriptive metadata handed to FastAPI.
app = FastAPI(
    title="Memo API - Transformers + Safetensors",
    description="Production-grade video generation API with proper ML security",
    version="2.0.0",
)
|
|
|
|
|
|
|
|
class VideoGenerationRequest(BaseModel):
    # Request body accepted by POST /generate.

    # Required: the text to turn into a video.
    text: str = Field(..., description="Bangla text content")

    # Optional; defaults to 15 and must lie in the closed range 5..60.
    duration: int = Field(
        default=15,
        ge=5,
        le=60,
        description="Video duration in seconds",
    )

    # Tier name. Not constrained at the schema level; the /generate handler
    # looks the tier up and rejects names it has no configuration for.
    tier: str = Field(
        default="free",
        description="Model tier (free, pro, enterprise)",
    )

    # Optional free-form style hint; None when the caller states no preference.
    style: Optional[str] = Field(
        default=None,
        description="Visual style preference",
    )

    class Config:
        # Sample payload attached to this model's schema.
        schema_extra = {
            "example": {
                "text": "আজকের দিনটি খুব সুন্দর ছিল। রোদ উজ্জ্বল এবং হাওয়া মৃদুমন্দ।",
                "duration": 15,
                "tier": "pro",
                "style": "realistic",
            }
        }
|
|
|
|
|
class SceneModel(BaseModel):
    # One planned scene. Instances are built from the dicts returned by the
    # scene planner (SceneModel(**scene)), so the field names below must
    # match the planner's dict keys.

    # Scene identifier.
    id: int
    # Text description; also used as the image-generation prompt.
    description: str

    # Timing values (units are not established in this file;
    # presumably seconds -- TODO confirm against the planner).
    duration: float
    start_time: float
    end_time: float

    # Presentation attributes supplied by the planner.
    visual_style: str
    transition_type: str
|
|
|
|
|
class GenerationStatus(BaseModel):
    # Progress record for a single generation request. One instance per
    # request is stored in the module-level status registry and returned
    # by GET /status/{request_id}.

    request_id: str

    # Lifecycle label; this module writes "pending", "processing",
    # "completed" and "failed".
    status: str

    # Percentage complete, declared within 0.0..100.0; starts at 0.0.
    progress: float = Field(default=0.0, ge=0.0, le=100.0)

    # Human-readable description of the current step, if any.
    message: Optional[str] = None

    # Planned scenes; stays None until scene planning has finished.
    scenes: Optional[List[SceneModel]] = None

    # Creation time and time of the most recent update.
    created_at: datetime
    updated_at: datetime
|
|
|
|
|
class VideoGenerationResponse(BaseModel):
    # Immediate reply of POST /generate, sent before the background work runs.

    # Identifier to poll with GET /status/{request_id}.
    request_id: str
    status: str
    message: str

    # Tier name taken from the request.
    tier_used: str

    # Filled from the tier's configured maximum scene count by the handler,
    # i.e. not the number of scenes actually planned.
    scenes_count: int

    # The requested duration, echoed back.
    estimated_duration: float

    # duration / 60 multiplied by the tier's credits-per-minute rate.
    credits_used: float

    # Outcome of the weight-security validation; True when the tier has
    # no LoRA path to validate.
    security_compliant: bool
|
|
|
|
|
|
|
|
# In-memory registry: request id -> GenerationStatus record.
# Entries are added by the /generate handler and never removed in this file.
generation_status: Dict[str, "GenerationStatus"] = {}

# In-memory registry: tier name -> {"scene_planner", "image_generator", "config"}.
# Populated by initialize_tier_managers().
tier_managers: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
|
|
|
def initialize_tier_managers(tiers: Optional[List[str]] = None):
    """Build and register the model managers for each tier.

    For every tier name, the tier configuration is looked up, a
    ``ScenePlanner`` and a ``SafeStableDiffusionGenerator`` are constructed
    from it, and the trio is stored in ``tier_managers`` under the tier name.

    Initialization is best-effort: a tier without configuration is skipped
    with a warning, and a tier whose construction raises is logged (with
    traceback) and skipped, so one broken tier does not stop the others.

    Args:
        tiers: Tier names to initialize. Defaults to
            ``["free", "pro", "enterprise"]`` when omitted.
    """
    if tiers is None:
        tiers = ["free", "pro", "enterprise"]

    for tier_name in tiers:
        try:
            tier_config = get_tier_config(tier_name)
            if not tier_config:
                logger.warning(f"No configuration found for tier: {tier_name}")
                continue

            logger.info(f"Initializing {tier_name} tier...")

            scene_planner = ScenePlanner(tier_config.text_model_id)

            image_generator = SafeStableDiffusionGenerator(
                model_id=tier_config.image_model_id,
                lora_path=tier_config.lora_path,
                use_lcm=tier_config.lcm_enabled,
            )

            # Register only after both models were constructed, so a
            # partially built tier is never exposed to request handlers.
            tier_managers[tier_name] = {
                "scene_planner": scene_planner,
                "image_generator": image_generator,
                "config": tier_config,
            }

            logger.info(f"{tier_name} tier initialized successfully")

        except Exception as e:
            # Deliberately broad: this is a per-tier isolation boundary.
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Failed to initialize {tier_name} tier: {e}")
|
|
|
|
|
|
|
|
async def process_video_generation(request_id: str, request: VideoGenerationRequest):
    """Background task: plan scenes and generate one frame per scene.

    Progress is reported by mutating the ``GenerationStatus`` record stored
    in ``generation_status[request_id]``; every update also refreshes
    ``updated_at`` so pollers can tell the job is still advancing.

    Any exception raised by the pipeline is logged with its traceback and
    recorded on the status record as ``"failed"``; it is not re-raised.

    Args:
        request_id: Key of the status record created by the /generate handler.
        request: The validated generation request.

    NOTE(review): ``plan_scenes`` and ``generate_frames`` are invoked
    synchronously inside this coroutine, so the event loop cannot serve
    other work while they run. Offloading them to an executor would need a
    check that the model objects tolerate concurrent use.
    """
    status = generation_status.get(request_id)
    if status is None:
        # Without a record there is nowhere to report progress or failure.
        logger.error(f"Video generation aborted: unknown request {request_id}")
        return

    def _report(state=None, progress=None, message=None):
        """Apply the given status fields and stamp the update time."""
        if state is not None:
            status.status = state
        if progress is not None:
            status.progress = progress
        if message is not None:
            status.message = message
        status.updated_at = datetime.now()

    try:
        _report(state="processing", progress=10.0, message="Initializing models...")

        tier_config = get_tier_config(request.tier)
        if not tier_config:
            raise ValueError(f"Invalid tier: {request.tier}")

        tier_manager = tier_managers.get(request.tier)
        if not tier_manager:
            raise ValueError(f"Tier manager not available: {request.tier}")

        _report(progress=20.0, message="Planning scenes...")

        scenes = tier_manager["scene_planner"].plan_scenes(
            text_bn=request.text,
            duration=request.duration,
        )

        status.scenes = [SceneModel(**scene) for scene in scenes]
        _report(progress=40.0, message="Generating frames...")

        image_generator = tier_manager["image_generator"]
        scene_total = len(scenes)
        generated_frames = []

        for position, scene in enumerate(scenes, start=1):
            # Frame generation spans the 40%..70% band of the progress bar.
            _report(
                progress=40.0 + (30.0 * position / scene_total),
                message=f"Generating frame {position}/{scene_total}...",
            )

            frames = image_generator.generate_frames(
                prompt=scene["description"],
                frames=1,
                width=tier_config.image_width,
                height=tier_config.image_height,
                num_inference_steps=tier_config.image_inference_steps,
                guidance_scale=tier_config.image_guidance_scale,
            )

            if frames:
                generated_frames.extend(frames)

            # Hand control back to the event loop between frames.
            await asyncio.sleep(0.1)

        _report(progress=80.0, message="Finalizing generation...")

        if tier_config.lora_path:
            security_result = validate_model_weights_security(tier_config.lora_path)
            # The result used to be stored in a list nobody read; surface a
            # negative verdict in the log without changing the job outcome.
            if not security_result["is_secure"]:
                logger.warning(
                    f"Model weights failed security validation for request {request_id}"
                )

        _report(
            state="completed",
            progress=100.0,
            message=f"Generated {len(generated_frames)} frames successfully",
        )

        logger.info(f"Video generation completed for request {request_id}")

    except Exception as e:
        # Top-level boundary of the background task: record and log the
        # failure (with traceback) instead of letting it escape.
        logger.exception(f"Video generation failed for request {request_id}: {e}")
        _report(state="failed", message=f"Generation failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the per-tier model managers when the application starts.

    ``initialize_tier_managers`` logs and skips tiers that fail to load
    rather than raising, so success is judged by whether any tier ended up
    registered in ``tier_managers``.
    """
    logger.info("Starting Memo API with Transformers + Safetensors")
    initialize_tier_managers()

    if tier_managers:
        logger.info("Application initialized successfully")
    else:
        # Avoid claiming success when no tier can serve a request.
        logger.warning("Application started but no model tiers could be initialized")
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    # Only tiers that were registered in tier_managers are reported.
    loaded_tiers = list(tier_managers)

    payload = {
        "status": "healthy",
        "version": "2.0.0",
        "transformers_version": "4.40.0+",
        "safetensors_enabled": True,
        "available_tiers": loaded_tiers,
    }
    return payload
|
|
|
|
|
@app.get("/tiers")
async def list_tiers():
    """List available model tiers."""
    tier_entries = []

    for tier_name, manager in tier_managers.items():
        cfg = manager["config"]

        # Public summary of the tier configuration.
        summary = {
            "description": cfg.description,
            "max_scenes": cfg.text_max_scenes,
            "image_resolution": f"{cfg.image_width}x{cfg.image_height}",
            # Reported as a flag only; the path itself is not exposed here.
            "lora_enabled": cfg.lora_path is not None,
            "lcm_enabled": cfg.lcm_enabled,
            "credits_per_minute": cfg.credits_per_minute,
        }

        tier_entries.append({"name": tier_name, "config": summary})

    return {"tiers": tier_entries}
|
|
|
|
|
@app.post("/generate", response_model=VideoGenerationResponse)
async def generate_video(
    request: VideoGenerationRequest,
    background_tasks: BackgroundTasks
):
    """
    Generate video content using transformer models and safetensors.

    This endpoint demonstrates the complete integration:
    - Bangla text parsing using Transformers
    - Scene planning with ML-based logic
    - Image generation with Stable Diffusion + Safetensors
    - Proper security validation
    - Tier-based resource management
    """
    # Contract:
    #   returns  VideoGenerationResponse carrying the request id to poll.
    #   raises   HTTPException 400 for blank text or an unknown tier,
    #            HTTPException 500 when the tier has no loaded manager or
    #            any unexpected error occurs.
    try:
        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text content cannot be empty")

        tier_config = get_tier_config(request.tier)
        if not tier_config:
            raise HTTPException(status_code=400, detail=f"Invalid tier: {request.tier}")

        tier_manager = tier_managers.get(request.tier)
        if not tier_manager:
            raise HTTPException(status_code=500, detail=f"Tier {request.tier} not available")

        # Everything that can fail runs BEFORE the status record is stored
        # and the task is queued. If this handler raises, the queued task
        # does not run, which used to leave a record stuck in "pending".
        estimated_duration = request.duration
        credits_used = (estimated_duration / 60.0) * tier_config.credits_per_minute

        security_compliant = True
        if tier_config.lora_path:
            security_result = validate_model_weights_security(tier_config.lora_path)
            security_compliant = security_result["is_secure"]

        request_id = str(uuid.uuid4())

        response = VideoGenerationResponse(
            request_id=request_id,
            status="processing",
            message="Video generation started",
            tier_used=request.tier,
            # Upper bound from the tier config; real scene count is known
            # only after planning and is reported via /status.
            scenes_count=tier_config.text_max_scenes,
            estimated_duration=estimated_duration,
            credits_used=credits_used,
            security_compliant=security_compliant,
        )

        # One timestamp so created_at and updated_at start out identical.
        now = datetime.now()
        generation_status[request_id] = GenerationStatus(
            request_id=request_id,
            status="pending",
            created_at=now,
            updated_at=now,
        )

        background_tasks.add_task(process_video_generation, request_id, request)

        logger.info(f"Video generation started for request {request_id} (tier: {request.tier})")
        return response

    except HTTPException:
        # Already a well-formed client/server error; pass it through.
        raise
    except Exception as e:
        logger.exception(f"Failed to start video generation: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
|
|
|
|
@app.get("/status/{request_id}", response_model=GenerationStatus)
async def get_generation_status(request_id: str):
    """Get the status of a video generation request."""
    # Single lookup; registry values are GenerationStatus objects, so a
    # None result can only mean the id was never registered.
    record = generation_status.get(request_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Request not found")

    return record
|
|
|
|
|
@app.get("/models/info")
async def get_models_info():
    """Get information about loaded models."""
    models_info = {}

    for tier_name, manager in tier_managers.items():
        try:
            scene_planner = manager["scene_planner"]
            # Not used below; the lookup is retained so a manager entry
            # missing this key is still reported as an error for the tier.
            image_generator = manager["image_generator"]
            config = manager["config"]

            text_section = {
                "model_id": config.text_model_id,
                "max_scenes": config.text_max_scenes,
                "device": scene_planner.parser.device,
            }

            image_section = {
                "model_id": config.image_model_id,
                "resolution": f"{config.image_width}x{config.image_height}",
                "inference_steps": config.image_inference_steps,
                "lora_path": config.lora_path,
                "lcm_enabled": config.lcm_enabled,
            }

            security_section = {
                "safetensors_only": config.safetensors_only,
                "model_signatures_required": config.model_signatures_required,
            }

            models_info[tier_name] = {
                "text_model": text_section,
                "image_model": image_section,
                "security": security_section,
            }
        except Exception as e:
            # Report the problem for this tier only and keep going.
            models_info[tier_name] = {"error": str(e)}

    return {"models": models_info}
|
|
|
|
|
@app.post("/security/validate")
async def validate_security(model_path: str):
    """Validate model weights for security compliance."""
    # NOTE(review): model_path comes straight from the caller and is passed
    # to the validator unchanged -- confirm the validator restricts which
    # filesystem locations it is willing to inspect.
    try:
        return validate_model_weights_security(model_path)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Security validation failed: {str(exc)}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: uvicorn is imported here so it is only required
    # when the module is executed directly.
    import uvicorn

    # Serve the app on port 8000, bound to all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=8000)