Varhal commited on
Commit
131db93
·
verified ·
1 Parent(s): 557a31a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -193
app.py CHANGED
@@ -1,12 +1,11 @@
1
  # Configuration
2
- prod = False # This variable is no longer used for launching, but kept for potential future use
3
- port = 8080 # This variable is no longer used for launching, but kept for potential future use
4
- show_options = False # This variable is no longer used for UI visibility
5
 
6
  import os
7
  import random
8
  import time
9
- # Removed gradio import as UI is being removed
10
  import numpy as np
11
  import spaces
12
  import imageio
@@ -19,20 +18,27 @@ from diffusers import (
19
  ControlNetModel,
20
  DPMSolverMultistepScheduler,
21
  StableDiffusionControlNetPipeline,
22
- # StableDiffusionInpaintPipeline, # Commented out as inpainting part was commented
23
- # AutoencoderKL, # Commented out as VAE part was commented
24
  )
25
  # Assuming controlnet_aux_local is a local package or needs to be installed separately
26
  from controlnet_aux_local import NormalBaeDetector
27
 
 
 
 
 
 
 
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  API_KEY = os.environ.get("API_KEY", None)
30
  # os.environ['HF_HOME'] = '/data/.huggingface'
31
  print("CUDA version:", torch.version.cuda)
32
  print("loading everything")
33
- compiled = False # This variable is no longer explicitly set to True after compilation print
34
  api = HfApi()
35
 
 
 
36
 
37
  class Preprocessor:
38
  MODEL_ID = "lllyasviel/Annotators"
@@ -46,8 +52,8 @@ class Preprocessor:
46
  return
47
  elif name == "NormalBae":
48
  print("Loading NormalBae")
49
- # Ensure model is moved to cuda if available
50
- self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda" if torch.cuda.is_available() else "cpu")
51
  if torch.cuda.is_available():
52
  torch.cuda.empty_cache()
53
  self.name = name
@@ -56,24 +62,22 @@ class Preprocessor:
56
  return
57
 
58
  def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
59
- # Ensure model is on the correct device before calling
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
  if self.model.device.type != device:
62
  self.model.to(device)
63
  return self.model(image, **kwargs)
64
 
65
- # Load models and preprocessor directly without gr.NO_RELOAD check
66
- # This block will execute when the script is imported or run
67
  # Controlnet Normal
68
- model_id = "lllyasviel/control_v11p_sd15_normalbae"
69
  print("initializing controlnet")
70
- # Ensure models are loaded onto the correct device
71
  device = "cuda" if torch.cuda.is_available() else "cpu"
72
  controlnet = ControlNetModel.from_pretrained(
73
  model_id,
74
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float32 if CUDA is not available
75
- attn_implementation="flash_attention_2" if torch.cuda.is_available() else None, # Flash attention only for CUDA
76
  ).to(device)
 
77
  # Scheduler
78
  scheduler = DPMSolverMultistepScheduler.from_pretrained(
79
  "ashllay/stable-diffusion-v1-5-archive",
@@ -85,21 +89,19 @@ scheduler = DPMSolverMultistepScheduler.from_pretrained(
85
  prediction_type="epsilon",
86
  thresholding=False,
87
  denoise_final=True,
88
- # device_map="cuda", # device_map can sometimes cause issues, better to move after loading
89
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float32 if CUDA is not available
90
  )
91
- # Removed this line as Schedulers don't have a .to() method
92
- # scheduler.to(device)
93
 
94
  # Stable Diffusion Pipeline URL
95
  base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
96
  print('loading pipe')
97
  pipe = StableDiffusionControlNetPipeline.from_single_file(
98
  base_model_url,
99
- safety_checker=None,
100
  controlnet=controlnet,
101
  scheduler=scheduler,
102
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float32 if CUDA is not available
103
  ).to(device)
104
 
105
  print("loading preprocessor")
@@ -108,68 +110,37 @@ preprocessor.load("NormalBae") # Preprocessor is loaded here
108
 
109
  # Load textual inversions
110
  try:
111
- pipe.load_textual_inversion(
112
- "broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",
113
- )
114
- pipe.load_textual_inversion(
115
- "broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4"
116
- )
117
- pipe.load_textual_inversion(
118
- "broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg"
119
- )
120
- pipe.load_textual_inversion(
121
- "broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao"
122
- )
123
- pipe.load_textual_inversion(
124
- "broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage"
125
- )
126
- pipe.load_textual_inversion(
127
- "broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play"
128
- )
129
- pipe.load_textual_inversion(
130
- "broyang/hentaidigitalart_v20",
131
- weight_name="HDA_unconventional maid.pt",
132
- token="HDA_unconventional_maid",
133
- )
134
- pipe.load_textual_inversion(
135
- "broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie"
136
- )
137
- pipe.load_textual_inversion(
138
- "broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress"
139
- )
140
- pipe.load_textual_inversion(
141
- "broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari"
142
- )
143
  except Exception as e:
144
- print(f"Error loading textual inversions: {e}")
145
- # Handle cases where loading textual inversions might fail, e.g., file not found
146
 
147
  print("---------------Loaded controlnet pipeline---------------")
148
  if torch.cuda.is_available():
149
  torch.cuda.empty_cache()
150
  gc.collect()
151
  print(f"CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")
152
- # Removed "Model Compiled!" print as compilation is not explicitly handled here
153
-
154
-
155
- # Removed generate_furniture_mask as inpainting part was commented out
156
- # def generate_furniture_mask(image, furniture_type):
157
- # image_np = np.array(image)
158
- # height, width = image_np.shape[:2]
159
- # mask = np.zeros((height, width), dtype=np.uint8)
160
- # if furniture_type == "sofa":
161
- # cv2.rectangle(mask, (width // 4, int(height * 0.6)), (width * 3 // 4, height), 255, -1)
162
- # elif furniture_type == "table":
163
- # cv2.rectangle(mask, (width // 3, height // 3), (width * 2 // 3, height * 2 // 3), 255, -1)
164
- # elif furniture_type == "chair":
165
- # cv2.circle(mask, (width * 3 // 5, height * 2 // 3), height // 6, 255, -1)
166
- # return Image.fromarray(mask)
167
-
168
- # Removed randomize_seed_fn as the logic is directly in process_image
169
- # def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
170
- # if randomize_seed:
171
- # seed = random.randint(0, MAX_SEED)
172
- # return seed
173
 
174
 
175
  def get_additional_prompt():
@@ -184,51 +155,29 @@ def get_additional_prompt():
184
 
185
  def get_prompt(prompt, additional_prompt):
186
  interior = "design-style interior designed (interior space),tungsten white balance,captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length"
187
- # default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed" # Not used
188
- # default2 = f"professional 3d model {prompt},octane render,highly detailed,volumetric,dramatic lighting,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed" # Not used
189
- randomize = get_additional_prompt()
190
- # nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed" # Not used
191
- # bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW" # Not used
192
- lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
193
- pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
194
- bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
195
- # ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao" # Not used
196
- ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
197
- athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
198
- atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
199
- maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
200
- nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
201
- naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
202
- abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
203
- # shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari" # Not used
204
- shibari2 = "octane render, highly detailed, volumetric, HDA_Shibari"
205
-
206
- if prompt == "":
207
- # This block seems to generate prompts for 'girls' which might not be relevant for interior design API
208
- # Consider if this random girl prompt generation is needed for the interior design API
209
- girls = [randomize, pet_play, bondage, lab_girl, athleisure, atompunk, maid, nundress,
210
- naked_hoodie, abg, shibari2, ahegao2]
211
- # prompts_nsfw = [abg, shibari2, ahegao2] # Not used
212
- # prompt = f"{random.choice(girls)}" # This line would overwrite the input prompt
213
- prompt = f"boho chic" # This line also overwrites the input prompt
214
- # The logic here seems inconsistent with using an input 'prompt'.
215
- # Assuming the intention is to use the input 'prompt' for custom designs,
216
- # and apply a style or default interior context.
217
- # Let's revise this to prioritize the input prompt.
218
- if additional_prompt:
219
- # Combine input prompt with additional prompt
220
- return f"{prompt}, {additional_prompt}"
221
- else:
222
- # If no additional prompt, just use the input prompt with interior context
223
- return f"Photo from Pinterest of {prompt} {interior}"
224
  else:
225
- # If a prompt is provided, use it with the interior context
226
- # The original logic here was redundant with the 'if prompt == ""' block
227
- # Let's simplify based on whether a prompt is provided
228
- if additional_prompt:
229
- return f"Photo from Pinterest of {prompt} {interior}, {additional_prompt}"
230
- else:
231
- return f"Photo from Pinterest of {prompt} {interior}"
 
 
 
 
 
 
 
 
 
 
232
 
233
 
234
  style_list = [
@@ -285,31 +234,22 @@ STYLE_NAMES = list(styles.keys())
285
 
286
 
287
  def apply_style(style_name):
288
- # Ensure style_name exists in styles dictionary
289
- return styles.get(style_name, "") # Return empty string if style not found
290
-
291
 
292
- # Removed css variable as it was for Gradio UI
293
- # css = """..."""
294
 
295
- # Removed gr.Blocks context manager and everything inside it
296
-
297
-
298
- # Modified process_image to be a standalone function callable by an API endpoint
299
- # Removed @spaces.GPU and @torch.inference_mode decorators if the API framework handles this
300
- # Added type hints for clarity
301
  def process_image_api(
302
  image: Image.Image,
303
  style_selection: str = "None",
304
  prompt: str = "",
305
  a_prompt: str = "",
306
  n_prompt: str = "EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
307
- num_images: int = 1, # Kept for potential future use, but pipeline currently generates 1
308
  image_resolution: int = 512,
309
  preprocess_resolution: int = 512,
310
  num_steps: int = 15,
311
  guidance_scale: float = 5.5,
312
- seed: int = -1, # Use -1 to indicate random seed if not provided
313
  ):
314
  """
315
  Processes an input image to generate a new image based on style and prompts.
@@ -320,7 +260,6 @@ def process_image_api(
320
  prompt: Custom design prompt.
321
  a_prompt: Additional positive prompt.
322
  n_prompt: Negative prompt.
323
- num_images: Number of images to generate (currently only 1 supported by pipeline).
324
  image_resolution: Resolution for the output image.
325
  preprocess_resolution: Resolution for the preprocessor.
326
  num_steps: Number of inference steps.
@@ -330,18 +269,14 @@ def process_image_api(
330
  Returns:
331
  A PIL Image of the generated result.
332
  """
333
- # Use provided seed or generate a random one
334
  current_seed = seed if seed != -1 else random.randint(0, MAX_SEED)
335
  generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed)
336
 
337
- # Ensure preprocessor is loaded
338
  if preprocessor.name != "NormalBae":
339
  preprocessor.load("NormalBae")
340
 
341
- # Ensure preprocessor model is on the correct device
342
  preprocessor.model.to("cuda" if torch.cuda.is_available() else "cpu")
343
 
344
- # Generate control image
345
  control_image = preprocessor(
346
  image=image,
347
  image_resolution=image_resolution,
@@ -350,50 +285,38 @@ def process_image_api(
350
 
351
  # Construct the full prompt
352
  if style_selection and style_selection != "None":
353
- # Apply selected style and combine with custom prompt and additional prompt
354
  style_prompt = apply_style(style_selection)
355
- # Combine prompts, ensuring no empty strings lead to awkward commas
356
  prompt_parts = [f"Photo from Pinterest of {prompt}" if prompt else None, style_prompt if style_prompt else None, a_prompt if a_prompt else None]
357
  full_prompt = ", ".join(filter(None, prompt_parts))
358
  else:
359
- # Use custom prompt and additional prompt with default interior context
360
  full_prompt = get_prompt(prompt, a_prompt)
361
 
 
362
  negative_prompt = str(n_prompt)
363
  print(f"Using prompt: {full_prompt}")
364
  print(f"Using negative prompt: {negative_prompt}")
365
  print(f"Using seed: {current_seed}")
366
 
367
-
368
- # Generate the image using the pipeline
369
- # Ensure the pipeline is on the correct device
370
  pipe.to("cuda" if torch.cuda.is_available() else "cpu")
371
 
372
- with torch.no_grad(): # Use no_grad for inference to save memory and speed
373
  initial_result = pipe(
374
  prompt=full_prompt,
375
  negative_prompt=negative_prompt,
376
  guidance_scale=guidance_scale,
377
- num_images_per_prompt=1, # Pipeline always generates 1 image here
378
  num_inference_steps=num_steps,
379
  generator=generator,
380
  image=control_image,
381
  ).images[0]
382
 
383
- # Save and upload results (optional, depending on API requirements)
384
- # This part might be handled by the API caller or a separate service
385
- # Keeping it for now as it was in the original script
386
  try:
387
  timestamp = int(time.time())
388
- # Saving input image is generally not needed for API response, but keeping for consistency
389
- # img_path = f"{timestamp}_input.jpg"
390
  results_path = f"{timestamp}_output.jpg"
391
- # imageio.imsave(img_path, image) # Removed saving input image
392
  imageio.imsave(results_path, initial_result)
393
 
394
- # Uploading files might not be desired for a general API,
395
- # consider making this optional or removing if the API just returns the image
396
- if API_KEY: # Only attempt upload if API_KEY is available
397
  print(f"Uploading result image to broyang/interior-ai-outputs/{results_path}")
398
  try:
399
  api.upload_file(
@@ -402,17 +325,8 @@ def process_image_api(
402
  repo_id="broyang/interior-ai-outputs",
403
  repo_type="dataset",
404
  token=API_KEY,
405
- run_as_future=True, # Asynchronous upload
406
  )
407
- # Removed input image upload
408
- # api.upload_file(
409
- # path_or_fileobj=img_path,
410
- # path_in_repo=img_path,
411
- # repo_id="broyang/interior-ai-outputs",
412
- # repo_type="dataset",
413
- # token=API_KEY,
414
- # run_as_future=True,
415
- # )
416
  except Exception as e:
417
  print(f"Error uploading file to Hugging Face Hub: {e}")
418
  else:
@@ -421,32 +335,85 @@ def process_image_api(
421
  except Exception as e:
422
  print(f"Error saving or uploading image: {e}")
423
 
424
-
425
  return initial_result
426
 
427
- # The script now defines the process_image_api function.
428
- # To use this as an API, you would typically import this script
429
- # into a web framework like FastAPI and define an endpoint that
430
- # calls process_image_api with the appropriate parameters from the request.
431
-
432
- # Example of how you might call the function (this part is for demonstration,
433
- # you would remove it when integrating into a web framework):
434
- # if __name__ == "__main__":
435
- # # Create a dummy input image (e.g., a black square)
436
- # dummy_image = Image.new('RGB', (512, 512), color = 'red')
437
- # print("Generating a sample image...")
438
- # # Call the processing function with sample parameters
439
- # generated_image = process_image_api(
440
- # image=dummy_image,
441
- # style_selection="Boho",
442
- # prompt="cozy living room",
443
- # a_prompt="warm lighting",
444
- # num_steps=20,
445
- # guidance_scale=7.0,
446
- # seed=42
447
- # )
448
- # # You can now save or display the generated_image
449
- # generated_image.save("sample_output.jpg")
450
- # print("Sample image generated and saved as sample_output.jpg")
451
-
452
- # Removed the demo.queue().launch() calls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Configuration
2
+ # These variables are now mostly for reference, FastAPI/Uvicorn handle port
3
+ prod = False
4
+ port = int(os.environ.get("PORT", 8080)) # Use PORT environment variable provided by Spaces, default to 8080
5
 
6
  import os
7
  import random
8
  import time
 
9
  import numpy as np
10
  import spaces
11
  import imageio
 
18
  ControlNetModel,
19
  DPMSolverMultistepScheduler,
20
  StableDiffusionControlNetPipeline,
 
 
21
  )
22
  # Assuming controlnet_aux_local is a local package or needs to be installed separately
23
  from controlnet_aux_local import NormalBaeDetector
24
 
25
+ # Import necessary components for FastAPI
26
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
27
+ from fastapi.responses import StreamingResponse
28
+ from pydantic import BaseModel
29
+ import uvicorn
30
+ import io
31
+
32
  MAX_SEED = np.iinfo(np.int32).max
33
  API_KEY = os.environ.get("API_KEY", None)
34
  # os.environ['HF_HOME'] = '/data/.huggingface'
35
  print("CUDA version:", torch.version.cuda)
36
  print("loading everything")
37
+ compiled = False
38
  api = HfApi()
39
 
40
+ # Initialize FastAPI app
41
+ app = FastAPI()
42
 
43
  class Preprocessor:
44
  MODEL_ID = "lllyasviel/Annotators"
 
52
  return
53
  elif name == "NormalBae":
54
  print("Loading NormalBae")
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(device)
57
  if torch.cuda.is_available():
58
  torch.cuda.empty_cache()
59
  self.name = name
 
62
  return
63
 
64
  def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
 
65
  device = "cuda" if torch.cuda.is_available() else "cpu"
66
  if self.model.device.type != device:
67
  self.model.to(device)
68
  return self.model(image, **kwargs)
69
 
70
+ # Load models and preprocessor when the script starts
 
71
  # Controlnet Normal
72
+ model_id = "lllylyasviel/control_v11p_sd15_normalbae" # Corrected model ID based on common usage
73
  print("initializing controlnet")
 
74
  device = "cuda" if torch.cuda.is_available() else "cpu"
75
  controlnet = ControlNetModel.from_pretrained(
76
  model_id,
77
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
78
+ attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
79
  ).to(device)
80
+
81
  # Scheduler
82
  scheduler = DPMSolverMultistepScheduler.from_pretrained(
83
  "ashllay/stable-diffusion-v1-5-archive",
 
89
  prediction_type="epsilon",
90
  thresholding=False,
91
  denoise_final=True,
92
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
93
  )
94
+ # Schedulers do not need to be moved to device
 
95
 
96
  # Stable Diffusion Pipeline URL
97
  base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
98
  print('loading pipe')
99
  pipe = StableDiffusionControlNetPipeline.from_single_file(
100
  base_model_url,
101
+ safety_checker=None, # Keep None for now, but consider enabling for public API
102
  controlnet=controlnet,
103
  scheduler=scheduler,
104
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
105
  ).to(device)
106
 
107
  print("loading preprocessor")
 
110
 
111
  # Load textual inversions
112
  try:
113
+ # List of textual inversions to load
114
+ textual_inversions = {
115
+ "EasyNegativeV2": "EasyNegativeV2.safetensors",
116
+ "badhandv4": "badhandv4.pt",
117
+ "fcNeg-neg": "fcNeg-neg.pt",
118
+ "HDA_Ahegao": "HDA_Ahegao.pt",
119
+ "HDA_Bondage": "HDA_Bondage.pt",
120
+ "HDA_pet_play": "HDA_pet_play.pt",
121
+ "HDA_unconventional_maid": "HDA_unconventional maid.pt",
122
+ "HDA_NakedHoodie": "HDA_NakedHoodie.pt",
123
+ "HDA_NunDress": "HDA_NunDress.pt",
124
+ "HDA_Shibari": "HDA_Shibari.pt",
125
+ }
126
+ for token, weight_name in textual_inversions.items():
127
+ try:
128
+ pipe.load_textual_inversion(
129
+ "broyang/hentaidigitalart_v20", weight_name=weight_name, token=token,
130
+ )
131
+ print(f"Loaded textual inversion: {token}")
132
+ except Exception as e:
133
+ print(f"Warning: Could not load textual inversion {weight_name}: {e}")
134
+
 
 
 
 
 
 
 
 
 
 
135
  except Exception as e:
136
+ print(f"Error during textual inversions loading process: {e}")
137
+
138
 
139
  print("---------------Loaded controlnet pipeline---------------")
140
  if torch.cuda.is_available():
141
  torch.cuda.empty_cache()
142
  gc.collect()
143
  print(f"CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
 
146
  def get_additional_prompt():
 
155
 
156
  def get_prompt(prompt, additional_prompt):
157
  interior = "design-style interior designed (interior space),tungsten white balance,captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length"
158
+
159
+ # Revised logic to prioritize the input prompt and combine with interior context and additional prompt
160
+ prompt_parts = []
161
+ if prompt:
162
+ prompt_parts.append(f"Photo from Pinterest of {prompt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  else:
164
+ # If no specific prompt, use a default or random one (original code's 'boho chic' or random 'girls' prompts)
165
+ # Let's stick to interior design context, so maybe a default interior style if no prompt?
166
+ # Or, based on the original code's `if prompt == "":` block, it seemed to sometimes
167
+ # default to random 'girl' prompts. This might be unintended for an interior design API.
168
+ # Let's assume if no prompt is given, we still apply the interior context.
169
+ prompt_parts.append("Photo from Pinterest of interior space") # Default if no prompt
170
+
171
+ prompt_parts.append(interior)
172
+
173
+ if additional_prompt:
174
+ prompt_parts.append(additional_prompt)
175
+ # Note: The original `get_prompt` had a block that randomly selected 'girl' related prompts
176
+ # when the input `prompt` was empty. This seems out of place for an interior design API.
177
+ # I have removed that random selection logic to focus on interior design prompts.
178
+ # If you need that random girl prompt functionality, please clarify where/how it should be used.
179
+
180
+ return ", ".join(filter(None, prompt_parts))
181
 
182
 
183
  style_list = [
 
234
 
235
 
236
  def apply_style(style_name):
237
+ return styles.get(style_name, "")
 
 
238
 
 
 
239
 
240
+ # The core processing function, now called by the API endpoint
241
+ @torch.inference_mode() # Keep inference_mode here for efficiency
 
 
 
 
242
  def process_image_api(
243
  image: Image.Image,
244
  style_selection: str = "None",
245
  prompt: str = "",
246
  a_prompt: str = "",
247
  n_prompt: str = "EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
 
248
  image_resolution: int = 512,
249
  preprocess_resolution: int = 512,
250
  num_steps: int = 15,
251
  guidance_scale: float = 5.5,
252
+ seed: int = -1,
253
  ):
254
  """
255
  Processes an input image to generate a new image based on style and prompts.
 
260
  prompt: Custom design prompt.
261
  a_prompt: Additional positive prompt.
262
  n_prompt: Negative prompt.
 
263
  image_resolution: Resolution for the output image.
264
  preprocess_resolution: Resolution for the preprocessor.
265
  num_steps: Number of inference steps.
 
269
  Returns:
270
  A PIL Image of the generated result.
271
  """
 
272
  current_seed = seed if seed != -1 else random.randint(0, MAX_SEED)
273
  generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed)
274
 
 
275
  if preprocessor.name != "NormalBae":
276
  preprocessor.load("NormalBae")
277
 
 
278
  preprocessor.model.to("cuda" if torch.cuda.is_available() else "cpu")
279
 
 
280
  control_image = preprocessor(
281
  image=image,
282
  image_resolution=image_resolution,
 
285
 
286
  # Construct the full prompt
287
  if style_selection and style_selection != "None":
 
288
  style_prompt = apply_style(style_selection)
 
289
  prompt_parts = [f"Photo from Pinterest of {prompt}" if prompt else None, style_prompt if style_prompt else None, a_prompt if a_prompt else None]
290
  full_prompt = ", ".join(filter(None, prompt_parts))
291
  else:
 
292
  full_prompt = get_prompt(prompt, a_prompt)
293
 
294
+
295
  negative_prompt = str(n_prompt)
296
  print(f"Using prompt: {full_prompt}")
297
  print(f"Using negative prompt: {negative_prompt}")
298
  print(f"Using seed: {current_seed}")
299
 
 
 
 
300
  pipe.to("cuda" if torch.cuda.is_available() else "cpu")
301
 
302
+ with torch.no_grad():
303
  initial_result = pipe(
304
  prompt=full_prompt,
305
  negative_prompt=negative_prompt,
306
  guidance_scale=guidance_scale,
307
+ num_images_per_prompt=1,
308
  num_inference_steps=num_steps,
309
  generator=generator,
310
  image=control_image,
311
  ).images[0]
312
 
313
+ # Save and upload results (optional)
 
 
314
  try:
315
  timestamp = int(time.time())
 
 
316
  results_path = f"{timestamp}_output.jpg"
 
317
  imageio.imsave(results_path, initial_result)
318
 
319
+ if API_KEY:
 
 
320
  print(f"Uploading result image to broyang/interior-ai-outputs/{results_path}")
321
  try:
322
  api.upload_file(
 
325
  repo_id="broyang/interior-ai-outputs",
326
  repo_type="dataset",
327
  token=API_KEY,
328
+ run_as_future=True,
329
  )
 
 
 
 
 
 
 
 
 
330
  except Exception as e:
331
  print(f"Error uploading file to Hugging Face Hub: {e}")
332
  else:
 
335
  except Exception as e:
336
  print(f"Error saving or uploading image: {e}")
337
 
 
338
  return initial_result
339
 
340
+ # Define a Pydantic model for the request body parameters (optional, but good practice)
341
+ # class ImageParameters(BaseModel):
342
+ # style_selection: str = "None"
343
+ # prompt: str = ""
344
+ # a_prompt: str = ""
345
+ # n_prompt: str = "EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)"
346
+ # image_resolution: int = 512
347
+ # preprocess_resolution: int = 512
348
+ # num_steps: int = 15
349
+ # guidance_scale: float = 5.5
350
+ # seed: int = -1
351
+
352
+ # Define the API endpoint
353
+ @app.post("/generate-image/")
354
+ async def generate_image(
355
+ file: UploadFile = File(...), # Input image file
356
+ style_selection: str = Form("None"), # Parameters from form data
357
+ prompt: str = Form(""),
358
+ a_prompt: str = Form(""),
359
+ n_prompt: str = Form("EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)"),
360
+ image_resolution: int = Form(512),
361
+ preprocess_resolution: int = Form(512),
362
+ num_steps: int = Form(15),
363
+ guidance_scale: float = Form(5.5),
364
+ seed: int = Form(-1),
365
+ ):
366
+ """
367
+ API endpoint to generate an interior design image based on an input image and parameters.
368
+
369
+ Expects a POST request with form-data including:
370
+ - file: The input image file (UploadFile).
371
+ - style_selection: The design style name (string).
372
+ - prompt: Custom design prompt (string).
373
+ - a_prompt: Additional positive prompt (string).
374
+ - n_prompt: Negative prompt (string).
375
+ - image_resolution: Output image resolution (int).
376
+ - preprocess_resolution: Preprocessor resolution (int).
377
+ - num_steps: Number of inference steps (int).
378
+ - guidance_scale: Guidance scale (float).
379
+ - seed: Random seed (int, use -1 for random).
380
+
381
+ Returns:
382
+ The generated image as a JPEG file.
383
+ """
384
+ try:
385
+ # Read the uploaded image file
386
+ image_data = await file.read()
387
+ input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
388
+
389
+ # Process the image using the core logic
390
+ generated_image = process_image_api(
391
+ image=input_image,
392
+ style_selection=style_selection,
393
+ prompt=prompt,
394
+ a_prompt=a_prompt,
395
+ n_prompt=n_prompt,
396
+ image_resolution=image_resolution,
397
+ preprocess_resolution=preprocess_resolution,
398
+ num_steps=num_steps,
399
+ guidance_scale=guidance_scale,
400
+ seed=seed,
401
+ )
402
+
403
+ # Return the generated image as a streaming response
404
+ buffer = io.BytesIO()
405
+ generated_image.save(buffer, format="JPEG")
406
+ buffer.seek(0)
407
+
408
+ return StreamingResponse(buffer, media_type="image/jpeg")
409
+
410
+ except Exception as e:
411
+ print(f"An error occurred during processing: {e}")
412
+ raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
413
+
414
+ # Entry point to run the FastAPI application using Uvicorn
415
+ if __name__ == "__main__":
416
+ # The host "0.0.0.0" makes the server accessible externally within the container
417
+ # The port is taken from the environment variable PORT, which Hugging Face Spaces sets
418
+ uvicorn.run(app, host="0.0.0.0", port=port)
419
+