File size: 6,782 Bytes
26575f5
 
e964fca
26575f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#  generation_function.py

import spaces # Import spaces first!
import torch
import numpy as np
import time
import imageio
import gc
from PIL import Image

# Import necessary components from your new modules
from model_loader import controlnet_pipe, inpaint_pipe, api
from preprocessor import Preprocessor
from style_utils import apply_style, style_list
from utils import randomize_seed_fn
from config import API_KEY

# Module-level annotator shared by generate_interior_design (ControlNet mode);
# its actual detector weights (e.g. NormalBae) are loaded lazily on first use.
preprocessor = Preprocessor()
# Preprocessor is loaded in app.py

def _build_full_prompt(prompt: str, style_selection: str, a_prompt: str) -> str:
    """Compose the final positive prompt from user text, style preset, and extras."""
    parts = [
        f"Photo from Pinterest of {prompt}" if prompt
        else "Photo from Pinterest of interior space"
    ]
    style_prompt_text = apply_style(style_selection)
    if style_prompt_text:
        parts.append(style_prompt_text)
    if a_prompt:
        parts.append(a_prompt)
    return ", ".join(filter(None, parts))


def _move_to_device(model, label: str) -> None:
    """Best-effort move of a model/pipeline to CUDA (or CPU) if it isn't there yet.

    Failures are logged and swallowed so generation continues on the
    current device rather than aborting the request.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if hasattr(model, "device") and model.device.type != device:
        print(f"Moving {label} to {device}")
        try:
            model.to(device)
        except Exception as e:
            print(f"Error moving {label} to {device}: {e}")


def _save_and_upload(result_image) -> None:
    """Save the result to a timestamped JPEG and (best-effort) upload it to the Hub.

    Both saving and uploading are optional side effects: any failure is
    logged and ignored so it never breaks the generation response.
    """
    try:
        results_path = f"{int(time.time())}_output.jpg"
        imageio.imsave(results_path, result_image)

        if API_KEY:
            print(f"Uploading result image to broyang/interior-ai-outputs/{results_path}")
            try:
                api.upload_file(
                    path_or_fileobj=results_path,
                    path_in_repo=results_path,
                    repo_id="broyang/interior-ai-outputs",
                    repo_type="dataset",
                    token=API_KEY,
                    run_as_future=True,  # fire-and-forget; don't block generation
                )
            except Exception as e:
                print(f"Error uploading file to Hugging Face Hub: {e}")
        else:
            print("Hugging Face API Key not found, skipping file upload.")
    except Exception as e:
        print(f"Error saving or uploading image: {e}")


@spaces.GPU(duration=12)
@torch.inference_mode()
def generate_interior_design(
        image_np: np.ndarray,
        mask_np: np.ndarray | None,  # optional mask; required for "Inpainting" mode
        mode: str,                   # "ControlNet" or "Inpainting"
        style_selection: str,
        prompt: str,
        a_prompt: str,               # additional positive prompt fragment
        n_prompt: str,               # negative prompt
        num_images: int,             # NOTE: pipelines below always generate exactly 1 image
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        randomize_seed: bool,
):
    """Generate an interior-design image via ControlNet or inpainting.

    Converts the numpy inputs to PIL, builds a style-augmented prompt, runs
    the pipeline selected by ``mode``, optionally saves/uploads the result,
    frees GPU memory, and returns the generated PIL image. For an unknown
    mode a blank white 512x512 image is returned so the UI still gets output.

    Raises:
        ValueError: if ``mode == "Inpainting"`` and no mask was supplied.
    """
    # --- Convert numpy inputs to PIL images --------------------------------
    image = Image.fromarray(image_np.astype(np.uint8)).convert("RGB")
    mask = None
    if mask_np is not None:
        # Accept both HxWxC and plain HxW masks (the original indexed
        # [:, :, 0] unconditionally and crashed on 2-D masks).
        mask_channel = mask_np[..., 0] if mask_np.ndim == 3 else mask_np
        mask = Image.fromarray(mask_channel.astype(np.uint8), "L")

    # --- Seeding -----------------------------------------------------------
    current_seed = randomize_seed_fn(seed, randomize_seed)
    # BUG FIX: torch.cuda.manual_seed() returns None, so the original code
    # silently passed generator=None to the pipelines on GPU and the seed was
    # ignored. A torch.Generator bound to the active device is the correct
    # way to get reproducible results.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device).manual_seed(current_seed)
    print(f"Using processed seed: {current_seed}")

    # --- Prompt construction -----------------------------------------------
    full_prompt = _build_full_prompt(prompt, style_selection, a_prompt)
    negative_prompt = str(n_prompt)
    print(f"Using prompt: {full_prompt}")
    print(f"Using negative prompt: {negative_prompt}")
    print(f"Selected mode: {mode}")

    initial_result = None

    if mode == "ControlNet":
        # Lazily load the NormalBae annotator used to condition the pipe.
        if preprocessor.name != "NormalBae":
            preprocessor.load("NormalBae")
        _move_to_device(preprocessor.model, "preprocessor model")

        control_image = preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )

        _move_to_device(controlnet_pipe, "controlnet pipe")
        # No inner torch.no_grad() needed: @torch.inference_mode() on the
        # function already disables gradient tracking.
        initial_result = controlnet_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=num_steps,
            generator=generator,
            image=control_image,
        ).images[0]

    elif mode == "Inpainting":
        if mask is None:
            # BUG FIX: the original raised gr.Error, but `gr` is never
            # imported in this module, so users got a NameError instead of
            # the intended message.
            raise ValueError("Inpainting mode requires a mask.")

        _move_to_device(inpaint_pipe, "inpaint pipe")
        initial_result = inpaint_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            image=image,       # original image to inpaint into
            mask_image=mask,   # white = regions to regenerate
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            generator=generator,
        ).images[0]

    # Save and upload results (optional, best-effort).
    if initial_result is not None:
        _save_and_upload(initial_result)

    # Clean up memory. gc.collect() helps on CPU too, so it runs
    # unconditionally; only the CUDA cache clearing is device-specific.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"CUDA memory allocated after generation: "
              f"{torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")

    if initial_result is None:
        # Unknown/unimplemented mode: return a blank canvas so the UI still
        # receives an image instead of crashing.
        print("No result generated for the selected mode.")
        return Image.new("RGB", (512, 512), (255, 255, 255))

    return initial_result