Cabo3616
/

File size: 4,785 Bytes
fa2b402
 
 
7b24fac
fa2b402
 
6c94d7b
fa2b402
6c94d7b
fa2b402
7b24fac
fa2b402
 
 
 
 
6c94d7b
7b24fac
fa2b402
 
7b24fac
 
 
fa2b402
7b24fac
 
 
 
 
 
 
 
 
 
 
 
fa2b402
7b24fac
 
 
 
 
 
 
 
 
 
 
045ce0f
 
7b24fac
 
 
 
 
 
 
 
fa2b402
7b24fac
 
 
 
 
 
fa2b402
 
 
 
 
 
 
 
 
 
 
6c94d7b
 
fa2b402
 
6c94d7b
 
 
 
 
 
 
 
 
 
 
fa2b402
 
6c94d7b
fa2b402
 
 
ed0a041
fa2b402
 
 
 
 
 
 
 
 
 
6c94d7b
fa2b402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c94d7b
 
 
 
 
 
 
 
fa2b402
 
7b24fac
fa2b402
 
 
 
 
 
ed0a041
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Custom handler for Hugging Face Inference Endpoints — LTX-Video.

Request: {"inputs": "prompt", "parameters": {...}}

Response: [{"generated_video": "<base64-encoded .mp4>"}]; the video is
encoded by an ffmpeg subprocess.
"""

import base64
import gc
import io
import os
import subprocess
import tempfile
from typing import Any, Dict

import numpy as np
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from PIL import Image


def _to_uint8(arr):
    """Return *arr* as ``uint8`` pixel data.

    Floating-point input of any precision (float16 included) is treated as
    normalized to [0, 1]: it is scaled by 255, clipped, and truncated.
    Non-float arrays are returned unchanged.
    """
    if np.issubdtype(arr.dtype, np.floating):
        return (arr * 255).clip(0, 255).astype(np.uint8)
    return arr


def _frame_to_image(frame):
    """Convert one frame (PIL image, tensor, or array-like) to a PIL image.

    Array data is assumed to be (H, W) or (H, W, C) as ``Image.fromarray``
    expects — TODO confirm for tensor inputs, which may be channel-first.
    """
    if isinstance(frame, Image.Image):
        return frame
    if isinstance(frame, torch.Tensor):
        # .numpy() needs a CPU tensor without autograd history, and NumPy has
        # no bfloat16, so floating tensors are upcast to float32 first.
        tensor = frame.detach().cpu()
        if tensor.is_floating_point():
            tensor = tensor.float()
        arr = tensor.numpy()
    elif hasattr(frame, "numpy"):
        arr = frame.numpy()
    else:
        arr = np.asarray(frame)
    return Image.fromarray(_to_uint8(arr))


def _frames_to_mp4(frames, fps: int = 16, timeout: float = 120) -> bytes:
    """Encode a sequence of frames as an H.264 MP4 and return the file bytes.

    Frames are written as PNGs into a temporary directory and encoded with the
    ``ffmpeg`` executable (no OpenCV/imageio dependency). The directory is
    removed on every exit path, including encoding failures and timeouts.

    Args:
        frames: Iterable of frames; each may be a PIL image, a torch tensor,
            an object exposing ``.numpy()``, or anything ``np.asarray``
            accepts.
        fps: Output frame rate.
        timeout: Seconds to wait for ffmpeg before
            ``subprocess.TimeoutExpired`` is raised.

    Returns:
        The encoded MP4 file contents.

    Raises:
        ValueError: If *frames* is empty.
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    # Materialize once so emptiness can be checked (also works for ndarrays,
    # whose truth value would be ambiguous).
    frames = list(frames)
    if not frames:
        raise ValueError("no frames to encode")

    # TemporaryDirectory cleans up even when a frame conversion, ffmpeg, or
    # the timeout raises; the previous manual cleanup only ran on success.
    with tempfile.TemporaryDirectory() as tmpdir:
        for i, frame in enumerate(frames):
            _frame_to_image(frame).save(os.path.join(tmpdir, f"frame_{i:05d}.png"))

        out_path = os.path.join(tmpdir, "output.mp4")
        cmd = [
            "ffmpeg", "-y",
            "-hide_banner", "-loglevel", "error",
            "-framerate", str(fps),
            "-i", os.path.join(tmpdir, "frame_%05d.png"),
            # libx264 with yuv420p chroma subsampling needs even dimensions;
            # round odd sizes down instead of failing the encode.
            "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            "-crf", "18",
            "-preset", "medium",
            out_path,
        ]

        result = subprocess.run(cmd, capture_output=True, timeout=timeout)
        if result.returncode != 0:
            # ffmpeg reports the actual error at the end of stderr, so keep
            # the tail; tolerate non-UTF-8 bytes in the output.
            stderr = result.stderr.decode(errors="replace")
            raise RuntimeError(f"ffmpeg failed: {stderr[-500:]}")

        with open(out_path, "rb") as f:
            return f.read()



class EndpointHandler:
    """Custom LTX-Video handler for Hugging Face Inference Endpoints.

    Request body: ``{"inputs": <prompt str>, "parameters": {...}}``.
    Recognised parameters (defaults in parentheses): ``num_frames`` (81),
    ``guidance_scale`` (5.0), ``num_inference_steps`` (30),
    ``negative_prompt``, ``seed``, ``height`` (512), ``width`` (704),
    ``fps`` (16), and ``image``. ``image`` is a base64 string, optionally
    wrapped in a ``data:`` URL; supplying it switches to image-to-video mode.

    Response: ``[{"generated_video": <base64-encoded mp4 str>}]``.
    """

    def __init__(self, path: str = ""):
        """Load the text-to-video pipeline from *path* and build an
        image-to-video pipeline that reuses its components.

        Args:
            path: Model directory; loaded with ``local_files_only=True``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 on GPU to save memory; float32 on CPU.
        dtype = torch.float16 if device == "cuda" else torch.float32

        # Text-to-video pipeline.
        self.pipe_t2v = LTXPipeline.from_pretrained(
            path, torch_dtype=dtype, local_files_only=True
        ).to(device)

        # Image-to-video pipeline built from the same module objects, so the
        # weights are not loaded a second time.
        self.pipe_i2v = LTXImageToVideoPipeline(
            vae=self.pipe_t2v.vae,
            text_encoder=self.pipe_t2v.text_encoder,
            tokenizer=self.pipe_t2v.tokenizer,
            transformer=self.pipe_t2v.transformer,
            scheduler=self.pipe_t2v.scheduler,
        ).to(device)

        if device == "cuda":
            # The VAE object is shared, so tiled decoding applies to both
            # pipelines (used to lower peak VRAM during decode).
            self.pipe_t2v.vae.enable_tiling()

        self.device = device

    @staticmethod
    def _param(params: Dict[str, Any], key: str, default: Any, cast=int) -> Any:
        """Return ``cast(params[key])``, or *default* when the key is missing or null."""
        value = params.get(key)
        return default if value is None else cast(value)

    @staticmethod
    def _strip_data_url(image_b64: str) -> str:
        """Drop a ``data:<mime>;base64,`` prefix if present.

        Without this, the prefix's letters would be decoded as image bytes.
        """
        if image_b64.startswith("data:"):
            _, _, image_b64 = image_b64.partition(",")
        return image_b64

    @classmethod
    def _decode_image(cls, image_b64: str):
        """Decode a base64 (optionally ``data:`` URL) image into an RGB PIL image."""
        payload = cls._strip_data_url(image_b64)
        return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")

    def _generation_kwargs(self, prompt: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the keyword arguments shared by both pipelines.

        Numeric parameters are coerced, so that JSON numbers like ``49.0`` or
        numeric strings reach the pipeline as real ``int``/``float`` values.
        Explicit ``null`` values fall back to the defaults.
        """
        seed = params.get("seed")
        generator = None
        if seed is not None:
            # Seeded generator on the model's device for reproducible output.
            generator = torch.Generator(device=self.device).manual_seed(int(seed))

        gen_kwargs = {
            "prompt": prompt,
            "num_frames": self._param(params, "num_frames", 81),
            "height": self._param(params, "height", 512),
            "width": self._param(params, "width", 704),
            "guidance_scale": self._param(params, "guidance_scale", 5.0, float),
            "num_inference_steps": self._param(params, "num_inference_steps", 30),
            "generator": generator,
        }
        negative_prompt = params.get("negative_prompt")
        if negative_prompt:
            gen_kwargs["negative_prompt"] = negative_prompt
        return gen_kwargs

    def __call__(self, data: Dict[str, Any]) -> list:
        """Generate a video for one request.

        Args:
            data: ``{"inputs": prompt, "parameters": {...}}``; see the class
                docstring for the recognised parameters.

        Returns:
            ``[{"generated_video": <base64-encoded mp4 str>}]``.
        """
        prompt = data.get("inputs", "")
        # "parameters" may be absent or JSON null.
        params = data.get("parameters") or {}

        gen_kwargs = self._generation_kwargs(prompt, params)
        fps = self._param(params, "fps", 16)
        image_b64 = params.get("image")

        result = None
        try:
            if image_b64:
                # I2V mode: condition the generation on the supplied image.
                gen_kwargs["image"] = self._decode_image(image_b64)
                result = self.pipe_i2v(**gen_kwargs)
            else:
                result = self.pipe_t2v(**gen_kwargs)
            video_bytes = _frames_to_mp4(result.frames[0], fps=fps)
        finally:
            # Release frames and cached GPU memory even if generation or
            # encoding raised (e.g. CUDA OOM), so the next request starts clean.
            del result
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return [{"generated_video": base64.b64encode(video_bytes).decode("utf-8")}]