Clean handler.py
Browse files- handler.py +0 -34
handler.py
CHANGED
|
@@ -1,43 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
import base64
|
| 3 |
import io
|
| 4 |
-
import sys
|
| 5 |
from typing import Dict, Any
|
| 6 |
from PIL import Image
|
| 7 |
|
| 8 |
-
# Create a fake cosmos_guardrail module BEFORE any diffusers import
|
| 9 |
-
class FakeCosmosSafetyChecker:
|
| 10 |
-
def __init__(self):
|
| 11 |
-
pass
|
| 12 |
-
|
| 13 |
-
def __call__(self, frames, **kwargs):
|
| 14 |
-
return frames
|
| 15 |
-
|
| 16 |
-
def check_text_safety(self, text):
|
| 17 |
-
return True
|
| 18 |
-
|
| 19 |
-
def check_video_safety(self, frames):
|
| 20 |
-
return frames
|
| 21 |
-
|
| 22 |
-
# Inject fake module into sys.modules
|
| 23 |
-
import types
|
| 24 |
-
fake_guardrail = types.ModuleType("cosmos_guardrail")
|
| 25 |
-
fake_guardrail.CosmosSafetyChecker = FakeCosmosSafetyChecker
|
| 26 |
-
sys.modules["cosmos_guardrail"] = fake_guardrail
|
| 27 |
-
|
| 28 |
-
# Now patch diffusers to think cosmos_guardrail is available
|
| 29 |
-
import diffusers.utils.import_utils as import_utils
|
| 30 |
-
original_is_available = getattr(import_utils, 'is_cosmos_guardrail_available', lambda: False)
|
| 31 |
-
import_utils.is_cosmos_guardrail_available = lambda: True
|
| 32 |
-
|
| 33 |
-
# Also patch at pipeline level if needed
|
| 34 |
-
try:
|
| 35 |
-
import diffusers.pipelines.cosmos.pipeline_cosmos2_video2world as cosmos_pipeline
|
| 36 |
-
cosmos_pipeline.is_cosmos_guardrail_available = lambda: True
|
| 37 |
-
cosmos_pipeline.CosmosSafetyChecker = FakeCosmosSafetyChecker
|
| 38 |
-
except:
|
| 39 |
-
pass
|
| 40 |
-
|
| 41 |
|
| 42 |
class EndpointHandler:
|
| 43 |
def __init__(self, path: str = ""):
|
|
|
|
| 1 |
import torch
|
| 2 |
import base64
|
| 3 |
import io
|
|
|
|
| 4 |
from typing import Dict, Any
|
| 5 |
from PIL import Image
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class EndpointHandler:
|
| 9 |
def __init__(self, path: str = ""):
|