Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,6 +25,17 @@ login(token=hf_token)
|
|
| 25 |
|
| 26 |
MAX_SEED = np.iinfo(np.int32).max
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 29 |
if randomize_seed:
|
| 30 |
seed = random.randint(0, MAX_SEED)
|
|
@@ -67,10 +78,9 @@ import PIL.Image as Image
|
|
| 67 |
|
| 68 |
base_model = 'briaai/BRIA-4B-Adapt'
|
| 69 |
controlnet_model = 'briaai/BRIA-4B-Adapt-ControlNet-Union'
|
| 70 |
-
|
| 71 |
controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
| 72 |
pipe = BriaControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, trust_remote_code=True)
|
| 73 |
-
pipe.to("cuda")
|
| 74 |
|
| 75 |
mode_mapping = {
|
| 76 |
"depth": 0,
|
|
@@ -172,6 +182,7 @@ def infer(cond_in, image_in, prompt, inference_steps, guidance_scale, control_mo
|
|
| 172 |
guidance_scale=guidance_scale,
|
| 173 |
generator=torch.manual_seed(seed),
|
| 174 |
max_sequence_length=128,
|
|
|
|
| 175 |
).images[0]
|
| 176 |
|
| 177 |
torch.cuda.empty_cache()
|
|
|
|
| 25 |
|
| 26 |
MAX_SEED = np.iinfo(np.int32).max
|
| 27 |
|
| 28 |
+
try:
|
| 29 |
+
local_dir = os.path.dirname(__file__)
|
| 30 |
+
except:
|
| 31 |
+
local_dir = '.'
|
| 32 |
+
|
| 33 |
+
hf_hub_download(repo_id="briaai/BRIA-4B-Adapt", filename='pipeline_bria.py', local_dir=local_dir)
|
| 34 |
+
hf_hub_download(repo_id="briaai/BRIA-4B-Adapt", filename='transformer_bria.py', local_dir=local_dir)
|
| 35 |
+
hf_hub_download(repo_id="briaai/BRIA-4B-Adapt", filename='bria_utils.py', local_dir=local_dir)
|
| 36 |
+
hf_hub_download(repo_id="briaai/BRIA-4B-Adapt-ControlNet-Union", filename='pipeline_bria_controlnet.py', local_dir=local_dir)
|
| 37 |
+
hf_hub_download(repo_id="briaai/BRIA-4B-Adapt-ControlNet-Union", filename='controlnet_bria.py', local_dir=local_dir)
|
| 38 |
+
|
| 39 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 40 |
if randomize_seed:
|
| 41 |
seed = random.randint(0, MAX_SEED)
|
|
|
|
| 78 |
|
| 79 |
base_model = 'briaai/BRIA-4B-Adapt'
|
| 80 |
controlnet_model = 'briaai/BRIA-4B-Adapt-ControlNet-Union'
|
|
|
|
| 81 |
controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
| 82 |
pipe = BriaControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, trust_remote_code=True)
|
| 83 |
+
pipe = pipeline.to(device="cuda", dtype=torch.bfloat16)
|
| 84 |
|
| 85 |
mode_mapping = {
|
| 86 |
"depth": 0,
|
|
|
|
| 182 |
guidance_scale=guidance_scale,
|
| 183 |
generator=torch.manual_seed(seed),
|
| 184 |
max_sequence_length=128,
|
| 185 |
+
negative_prompt="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate"
|
| 186 |
).images[0]
|
| 187 |
|
| 188 |
torch.cuda.empty_cache()
|