File size: 12,975 Bytes
a8fc815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
"""
Production API Endpoint
Demonstrates complete Transformers + Safetensors integration with tier management
"""

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import logging
import uuid
from datetime import datetime
import asyncio

# Import our modules
from core.scene_planner import get_planner, ScenePlanner
from models.image.sd_generator import get_generator, SafeStableDiffusionGenerator
from config.model_tiers import get_tier_config, validate_model_weights_security

# Configure logging: an INFO-level root handler plus a logger scoped to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app. This metadata is what FastAPI publishes in the
# generated OpenAPI document.
_APP_METADATA = {
    "title": "Memo API - Transformers + Safetensors",
    "description": "Production-grade video generation API with proper ML security",
    "version": "2.0.0",
}
app = FastAPI(**_APP_METADATA)

# Request/Response Models

# Body of POST /generate. The bounds declared with Field(...) (e.g. duration
# between 5 and 60) are enforced by pydantic before the handler runs. `tier` is
# a free-form string; it is only checked against get_tier_config() inside the
# handler.
class VideoGenerationRequest(BaseModel):
    text: str = Field(..., description="Bangla text content")
    duration: int = Field(15, ge=5, le=60, description="Video duration in seconds")
    tier: str = Field("free", description="Model tier (free, pro, enterprise)")
    # NOTE(review): `style` is accepted but never read anywhere in this module.
    style: Optional[str] = Field(None, description="Visual style preference")

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        # NOTE(review): `schema_extra` is the pydantic v1 key; pydantic v2 renamed
        # it to `json_schema_extra`. Confirm which major version is installed.
        schema_extra = {
            "example": {
                "text": "আজকের দিনটি খুব সুন্দর ছিল। রোদ উজ্জ্বল এবং হাওয়া মৃদুমন্দ।",
                "duration": 15,
                "tier": "pro",
                "style": "realistic"
            }
        }

# One planned scene, as reported in GenerationStatus.scenes. Instances are built
# from the dicts returned by ScenePlanner.plan_scenes() via SceneModel(**scene),
# so these field names must match the keys that planner emits.
class SceneModel(BaseModel):
    id: int
    description: str  # also used verbatim as the image-generation prompt
    duration: float  # presumably seconds, like the request duration; TODO confirm
    start_time: float
    end_time: float
    visual_style: str
    transition_type: str

# Progress record for one /generate request. It is stored in generation_status,
# returned by GET /status/{request_id}, and mutated in place by the background
# task.
class GenerationStatus(BaseModel):
    request_id: str
    status: str  # "pending", "processing", "completed", "failed"
    # Percentage. The ge/le bounds are checked when the model is constructed;
    # validate_assignment is not enabled here, so later attribute writes are not re-checked.
    progress: float = Field(0.0, ge=0.0, le=100.0)
    message: Optional[str] = None  # human-readable current step, or the error text
    scenes: Optional[List[SceneModel]] = None  # filled in once scene planning finishes
    created_at: datetime  # naive local time (datetime.now())
    updated_at: datetime  # naive local time (datetime.now())

# Immediate reply from POST /generate; the actual work continues in a background task.
class VideoGenerationResponse(BaseModel):
    request_id: str  # key to poll with GET /status/{request_id}
    status: str
    message: str
    tier_used: str
    scenes_count: int  # the tier's text_max_scenes (an upper bound), not the planned count
    estimated_duration: float  # the requested duration, in seconds
    credits_used: float  # (duration / 60) * the tier's credits_per_minute
    security_compliant: bool  # False only when the tier's LoRA fails weight validation

# Global state management: in-process registries shared by every endpoint.
# NOTE(review): both exist only in this worker's memory, and nothing in this
# module removes entries from generation_status, so it grows with every request.

# request_id -> GenerationStatus record for that request
generation_status: "Dict[str, GenerationStatus]" = dict()
# tier name -> {"scene_planner": ..., "image_generator": ..., "config": ...}
tier_managers: "Dict[str, Dict[str, Any]]" = dict()

# Initialize tier managers
def initialize_tier_managers(tiers=("free", "pro", "enterprise")):
    """Build a scene planner and an image generator for each configured tier.

    Each successfully loaded tier is stored in the module-level
    ``tier_managers`` dict under its name, as
    ``{"scene_planner", "image_generator", "config"}``. A tier with no
    configuration, or whose models fail to load, is logged and skipped, so
    the API can still serve the remaining tiers.

    Args:
        tiers: Tier names to initialize. Defaults to the three built-in tiers.
    """
    for tier_name in tiers:
        try:
            tier_config = get_tier_config(tier_name)
            if not tier_config:
                logger.warning(f"No configuration found for tier: {tier_name}")
                continue

            logger.info(f"Initializing {tier_name} tier...")

            # Text model used to split the input into scenes.
            scene_planner = ScenePlanner(tier_config.text_model_id)

            # Image model (optionally with LoRA weights / LCM) used to render frames.
            image_generator = SafeStableDiffusionGenerator(
                model_id=tier_config.image_model_id,
                lora_path=tier_config.lora_path,
                use_lcm=tier_config.lcm_enabled,
            )

            tier_managers[tier_name] = {
                "scene_planner": scene_planner,
                "image_generator": image_generator,
                "config": tier_config,
            }

            logger.info(f"{tier_name} tier initialized successfully")

        except Exception as e:
            # Deliberately best-effort: one broken tier must not stop the others.
            # logger.exception keeps the traceback, which model-load failures need.
            logger.exception(f"Failed to initialize {tier_name} tier: {e}")

# Background processing

# Created lazily inside a coroutine (see _run_model_call) so it binds to the
# running event loop even on Python < 3.10.
_model_call_lock: Optional[asyncio.Lock] = None


async def _run_model_call(func, *args, **kwargs):
    """Run a blocking model call in the default thread pool, one call at a time.

    Offloading keeps the event loop free to answer /status and other requests
    while inference runs. The lock keeps model calls serialised, as they were
    when they ran directly on the loop, because the shared model objects are
    not known to be thread-safe.
    """
    global _model_call_lock
    if _model_call_lock is None:
        _model_call_lock = asyncio.Lock()
    loop = asyncio.get_running_loop()
    async with _model_call_lock:
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


def _update_status(record, **fields):
    """Set ``fields`` on a GenerationStatus record and refresh its ``updated_at``."""
    for name, value in fields.items():
        setattr(record, name, value)
    record.updated_at = datetime.now()


async def process_video_generation(request_id: str, request: VideoGenerationRequest):
    """Background task: plan scenes, then render one frame per scene.

    Progress is reported by mutating the GenerationStatus stored at
    ``generation_status[request_id]``, which the /generate handler creates.
    Any exception marks the request "failed" with the error text; nothing
    propagates to the caller.

    Args:
        request_id: Key of an existing entry in ``generation_status``.
        request: The validated /generate payload.
    """
    status = generation_status.get(request_id)
    if status is None:
        # Nothing to report progress into; previously this raised KeyError twice.
        logger.error(f"Video generation failed for request {request_id}: unknown request id")
        return

    try:
        _update_status(status, status="processing", progress=10.0,
                       message="Initializing models...")

        # Get tier configuration
        tier_config = get_tier_config(request.tier)
        if not tier_config:
            raise ValueError(f"Invalid tier: {request.tier}")

        tier_manager = tier_managers.get(request.tier)
        if not tier_manager:
            raise ValueError(f"Tier manager not available: {request.tier}")

        # Security validation runs BEFORE any frames are generated. Previously it
        # ran afterwards and its result was discarded, so it never had any effect.
        if tier_config.lora_path:
            security_result = validate_model_weights_security(tier_config.lora_path)
            if not security_result["is_secure"]:
                raise ValueError(
                    f"LoRA weights failed security validation: {tier_config.lora_path}"
                )

        _update_status(status, progress=20.0, message="Planning scenes...")

        # Step 1: Plan scenes using the transformer model (off the event loop).
        scenes = await _run_model_call(
            tier_manager["scene_planner"].plan_scenes,
            text_bn=request.text,
            duration=request.duration,
        )

        _update_status(
            status,
            scenes=[SceneModel(**scene) for scene in scenes],
            progress=40.0,
            message="Generating frames...",
        )

        # Step 2: Generate one image per scene with Stable Diffusion + Safetensors.
        # NOTE(review): frames are only counted; this module never stores or returns them.
        generated_frames = []
        total = len(scenes)
        for done, scene in enumerate(scenes, start=1):
            _update_status(status, message=f"Generating frame {done}/{total}...")

            frames = await _run_model_call(
                tier_manager["image_generator"].generate_frames,
                prompt=scene["description"],
                frames=1,  # Generate one frame per scene
                width=tier_config.image_width,
                height=tier_config.image_height,
                num_inference_steps=tier_config.image_inference_steps,
                guidance_scale=tier_config.image_guidance_scale,
            )
            if frames:
                generated_frames.extend(frames)

            # Report progress only after the frame exists (40% -> 70% across scenes).
            _update_status(status, progress=40.0 + (30.0 * done / total))

        _update_status(status, progress=80.0, message="Finalizing generation...")

        # Finalize
        _update_status(
            status,
            status="completed",
            progress=100.0,
            message=f"Generated {len(generated_frames)} frames successfully",
        )
        logger.info(f"Video generation completed for request {request_id}")

    except Exception as e:
        logger.exception(f"Video generation failed for request {request_id}: {e}")
        _update_status(status, status="failed", message=f"Generation failed: {str(e)}")

# API Endpoints

@app.on_event("startup")
async def startup_event():
    """Startup hook: build every tier manager before requests are served.

    initialize_tier_managers() logs and skips tiers that fail to load, so this
    hook completes even when some or all tiers are unavailable.
    """
    announce = logger.info
    announce("Starting Memo API with Transformers + Safetensors")
    initialize_tier_managers()
    announce("Application initialized successfully")

@app.get("/health")
async def health_check():
    """Health check endpoint.

    Reports liveness, the API version and the tiers that finished
    initializing. ``status`` is always "healthy"; inspect ``available_tiers``
    to see whether any models actually loaded.
    """
    return {
        "status": "healthy",
        # Read from the FastAPI app so the two version strings cannot drift apart.
        "version": app.version,
        # NOTE(review): hard-coded label, not the installed transformers version.
        "transformers_version": "4.40.0+",
        # NOTE(review): also a constant; nothing in this handler checks safetensors usage.
        "safetensors_enabled": True,
        "available_tiers": list(tier_managers.keys()),
    }

@app.get("/tiers")
async def list_tiers():
    """Describe each initialized tier and its user-visible settings."""
    summaries = []
    for tier_name, manager in tier_managers.items():
        cfg = manager["config"]
        public_settings = {
            "description": cfg.description,
            "max_scenes": cfg.text_max_scenes,
            "image_resolution": f"{cfg.image_width}x{cfg.image_height}",
            "lora_enabled": cfg.lora_path is not None,
            "lcm_enabled": cfg.lcm_enabled,
            "credits_per_minute": cfg.credits_per_minute,
        }
        summaries.append({"name": tier_name, "config": public_settings})
    return {"tiers": summaries}

@app.post("/generate", response_model=VideoGenerationResponse)
async def generate_video(
    request: VideoGenerationRequest,
    background_tasks: BackgroundTasks
):
    """
    Generate video content using transformer models and safetensors.

    This endpoint demonstrates the complete integration:
    - Bangla text parsing using Transformers
    - Scene planning with ML-based logic
    - Image generation with Stable Diffusion + Safetensors
    - Proper security validation
    - Tier-based resource management
    """
    try:
        # Validate request
        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text content cannot be empty")

        tier_config = get_tier_config(request.tier)
        if not tier_config:
            raise HTTPException(status_code=400, detail=f"Invalid tier: {request.tier}")

        tier_manager = tier_managers.get(request.tier)
        if not tier_manager:
            raise HTTPException(status_code=500, detail=f"Tier {request.tier} not available")

        # Everything that can fail runs BEFORE the request is registered and
        # scheduled. Otherwise an exception here would leave a "pending" entry in
        # generation_status that no task ever advances.

        # Calculate estimated costs
        estimated_duration = request.duration
        credits_used = (estimated_duration / 60.0) * tier_config.credits_per_minute

        # Security compliance check (reported to the client, not enforced here)
        security_compliant = True
        if tier_config.lora_path:
            security_result = validate_model_weights_security(tier_config.lora_path)
            security_compliant = security_result["is_secure"]

        request_id = str(uuid.uuid4())

        response = VideoGenerationResponse(
            request_id=request_id,
            status="processing",
            message="Video generation started",
            tier_used=request.tier,
            scenes_count=tier_config.text_max_scenes,
            estimated_duration=estimated_duration,
            credits_used=credits_used,
            security_compliant=security_compliant,
        )

        # Register status tracking; one timestamp so created_at == updated_at initially.
        now = datetime.now()
        generation_status[request_id] = GenerationStatus(
            request_id=request_id,
            status="pending",
            created_at=now,
            updated_at=now,
        )

        # Start background processing (runs after the response is sent).
        background_tasks.add_task(process_video_generation, request_id, request)

        logger.info(f"Video generation started for request {request_id} (tier: {request.tier})")
        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to start video generation: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

@app.get("/status/{request_id}", response_model=GenerationStatus)
async def get_generation_status(request_id: str):
    """Return the tracked progress record for ``request_id``, or respond 404."""
    try:
        return generation_status[request_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Request not found") from None

@app.get("/models/info")
async def get_models_info():
    """Summarise the text and image model settings of every initialized tier.

    A tier whose details cannot be read is reported as
    ``{"error": "<message>"}`` instead of failing the whole response.
    """
    models_info = {}

    for tier_name, manager in tier_managers.items():
        try:
            config = manager["config"]
            device = manager["scene_planner"].parser.device
            # The device value may not be JSON-serialisable (e.g. a torch.device;
            # TODO confirm the parser's type). Stringify anything that is not
            # already a str or None so the response can always be encoded.
            if device is not None and not isinstance(device, str):
                device = str(device)

            models_info[tier_name] = {
                "text_model": {
                    "model_id": config.text_model_id,
                    "max_scenes": config.text_max_scenes,
                    "device": device,
                },
                "image_model": {
                    "model_id": config.image_model_id,
                    "resolution": f"{config.image_width}x{config.image_height}",
                    "inference_steps": config.image_inference_steps,
                    "lora_path": config.lora_path,
                    "lcm_enabled": config.lcm_enabled,
                },
                "security": {
                    "safetensors_only": config.safetensors_only,
                    "model_signatures_required": config.model_signatures_required,
                },
            }
        except Exception as e:
            models_info[tier_name] = {"error": str(e)}

    return {"models": models_info}

@app.post("/security/validate")
async def validate_security(model_path: str):
    """Validate model weights for security compliance.

    Returns whatever ``validate_model_weights_security`` returns for
    ``model_path``. Any failure is logged and surfaced as HTTP 500.

    NOTE(review): ``model_path`` is client-supplied (a query parameter) and is
    passed to the validator unchecked. Any caller can therefore point the
    server at arbitrary filesystem paths, which the validator presumably
    reads. Consider restricting it to the configured tier LoRA paths or an
    allow-listed directory, and/or requiring authentication.
    """
    try:
        return validate_model_weights_security(model_path)
    except Exception as e:
        # repr() keeps a hostile path (e.g. one containing newlines) on one log line.
        logger.exception(f"Security validation failed for {model_path!r}: {e}")
        raise HTTPException(status_code=500, detail=f"Security validation failed: {str(e)}") from e

if __name__ == "__main__":
    # Script entry point: serve `app` on every network interface at port 8000.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)