Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,7 +33,7 @@ from PIL import Image
|
|
| 33 |
from torchvision.utils import make_grid, save_image
|
| 34 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 35 |
|
| 36 |
-
from app import safety_check
|
| 37 |
from app.sana_pipeline import SanaPipeline
|
| 38 |
|
| 39 |
MAX_SEED = np.iinfo(np.int32).max
|
|
@@ -205,12 +205,12 @@ if torch.cuda.is_available():
|
|
| 205 |
pipe.register_progress_bar(gr.Progress())
|
| 206 |
|
| 207 |
# safety checker
|
| 208 |
-
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 209 |
-
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
).to(device)
|
| 214 |
|
| 215 |
|
| 216 |
def save_image_sana(img, seed="", save_img=False):
|
|
@@ -254,8 +254,8 @@ def generate(
|
|
| 254 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 255 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 256 |
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
| 257 |
-
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
| 258 |
-
|
| 259 |
|
| 260 |
print(prompt)
|
| 261 |
|
|
|
|
| 33 |
from torchvision.utils import make_grid, save_image
|
| 34 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 35 |
|
| 36 |
+
#from app import safety_check
|
| 37 |
from app.sana_pipeline import SanaPipeline
|
| 38 |
|
| 39 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 205 |
pipe.register_progress_bar(gr.Progress())
|
| 206 |
|
| 207 |
# safety checker
|
| 208 |
+
#safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 209 |
+
#safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 210 |
+
# args.shield_model_path,
|
| 211 |
+
# device_map="auto",
|
| 212 |
+
# torch_dtype=torch.bfloat16,
|
| 213 |
+
#).to(device)
|
| 214 |
|
| 215 |
|
| 216 |
def save_image_sana(img, seed="", save_img=False):
|
|
|
|
| 254 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 255 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 256 |
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
| 257 |
+
#if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
| 258 |
+
# prompt = "A red heart."
|
| 259 |
|
| 260 |
print(prompt)
|
| 261 |
|