Deadmon committed on
Commit
2a2aa3b
·
verified ·
1 Parent(s): e3ccb0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -126
app.py CHANGED
@@ -1,145 +1,163 @@
1
- from huggingface_hub import upload_file, create_repo
 
2
  import gradio as gr
3
  import os
4
- import requests
5
- import tempfile
6
- import re
7
 
8
# Markdown help text rendered below the UI (usage notes for CivitAI links and API keys).
article = """
Some things to note:
* To obtain the download link, right click "Download" on CivitAI and Copy Link Address.
* If the model requires login, you must provide your CivitAI API Key.
* API Key location: https://civitai.com/user/account -> API Keys.
"""
14
 
15
# Updated signature: Added defaults (=None) to ensure Gradio maps inputs correctly
def download_locally_and_upload_to_hub(civit_url, repo_id, hf_token=None, civitai_api_key=None, progress=gr.Progress()):
    """Download a checkpoint from CivitAI and push it to a Hugging Face repo.

    Args:
        civit_url: Direct CivitAI download URL (right-click "Download" -> Copy Link Address).
        repo_id: Target Hugging Face repo ("user/model"); created if it does not exist.
        hf_token: Hugging Face write token (required).
        civitai_api_key: Optional CivitAI API key for models that require login.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A Markdown status string: a success link on success, or an error message.
    """
    # Sanitize inputs
    hf_token = hf_token.strip() if hf_token else None
    civitai_api_key = civitai_api_key.strip() if civitai_api_key else None

    if not civit_url:
        return "Error: Please provide a CivitAI URL."
    if not repo_id:
        return "Error: Please provide a Hugging Face Repo ID."
    if not hf_token:
        return "Error: Please provide a Hugging Face Write Token."

    # Initialize state
    repo_created = False
    file_committed = False
    commit_info = None

    # Browser-like User-Agent headers to avoid 403 errors from CivitAI.
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
    }
    if civitai_api_key:
        headers["Authorization"] = f"Bearer {civitai_api_key}"
        headers["Accept"] = "application/json"

    # --- STEP 1: DOWNLOAD ---
    progress(0, desc="Connecting to CivitAI...")
    try:
        # FIX: add (connect, read) timeouts so a stalled connection cannot hang the worker forever.
        response = requests.get(civit_url, headers=headers, stream=True, timeout=(10, 60))
        response.raise_for_status()
    except Exception as e:
        return f"Error connecting to CivitAI: {e}"

    # Extract filename from the Content-Disposition header when available.
    filename = None
    cd = response.headers.get("Content-Disposition")
    if cd:
        fname_match = re.findall('filename="?([^"]+)"?', cd)
        if fname_match:
            filename = fname_match[0]

    if not filename:
        filename = civit_url.split("/")[-1]

    # Clean filename and ensure extension
    filename = filename.split("?")[0]  # Remove query params if any
    if "." not in filename:
        filename += ".safetensors"

    # --- STEP 2: SAVE & UPLOAD ---
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024 * 1024  # 1MB chunks

    with tempfile.TemporaryDirectory() as temp_dir:
        local_path = os.path.join(temp_dir, filename)

        try:
            downloaded = 0
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=block_size):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        # Update Progress
                        if total_size > 0:
                            # FIX: interpolate the actual filename (f-string had no placeholder).
                            progress(downloaded / total_size, desc=f"Downloading {filename}...")
                        else:
                            progress(0.5, desc="Downloading (size unknown)...")
        except Exception as e:
            return f"Error saving locally: {e}"
        finally:
            # FIX: always release the streamed HTTP connection.
            response.close()

        # Create Repo and Upload (repo_id was already validated above, so the
        # old `if repo_id:` re-check was redundant and has been dropped).
        progress(0.8, desc="Creating HF Repo...")
        try:
            create_repo(repo_id=repo_id, exist_ok=True, token=hf_token)
            repo_created = True
        except Exception as e:
            return f"Error creating repo: {e}"

        if repo_created:
            # FIX: interpolate the actual filename (f-string had no placeholder).
            progress(0.9, desc=f"Uploading {filename} to Hub...")
            try:
                commit_info = upload_file(
                    repo_id=repo_id,
                    path_or_fileobj=local_path,
                    path_in_repo=filename,
                    token=hf_token
                )
                file_committed = True
            except Exception as e:
                return f"Error uploading file: {e}"

        if file_committed:
            # Construct URL safely; FIX: fallback URL now points at the uploaded file.
            url = commit_info.commit_url if hasattr(commit_info, 'commit_url') else f"https://huggingface.co/{repo_id}/blob/main/{filename}"
            return f"## Success! \n Model pushed to: [{url}]({url})"

    return "Unknown error occurred."
115
 
116
def get_gradio_demo():
    """Assemble the Gradio Blocks UI and hand it back to the caller."""
    with gr.Blocks() as demo:
        # Page header plus the module-level usage notes.
        gr.Markdown("# Upload CivitAI checkpoints to the HF Hub 🤗")
        gr.Markdown(article)

        with gr.Row():
            url_box = gr.Textbox(label="CivitAI Download URL")
            repo_box = gr.Textbox(label="Repo ID (user/model)")

        with gr.Row():
            token_box = gr.Textbox(label="HF Write Token", type="password")
            civitai_key_box = gr.Textbox(label="CivitAI API Key (Optional)", type="password")

        submit = gr.Button("Download & Upload", variant="primary")
        status_md = gr.Markdown()

        # Positional input order mirrors the handler's parameter order exactly.
        submit.click(
            fn=download_locally_and_upload_to_hub,
            inputs=[url_box, repo_box, token_box, civitai_key_box],
            outputs=status_md,
        )

    return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
if __name__ == "__main__":
    demo = get_gradio_demo()
    # queue() is required for progress bars to work correctly
    demo.queue()
    # show_error surfaces handler tracebacks in the browser for easier debugging.
    demo.launch(show_error=True)
 
1
+ import torch
2
+ import spaces
3
  import gradio as gr
4
  import os
5
+ from diffusers import DiffusionPipeline
6
+ from huggingface_hub import list_repo_files
 
7
 
8
# ================= CONFIGURATION =================
# Hub repository the pipeline is loaded from; change this to serve another model.
REPO_ID = "deazbooney/Z-Image-Turbo-NSFW"
# =================================================

print(f"Initializing pipeline for: {REPO_ID}...")
 
 
 
 
13
 
14
def load_pipeline(repo_id):
    """
    Robust loader that handles both standard Diffusers folders
    and single-file .safetensors checkpoints (CivitAI style).

    Args:
        repo_id: Hugging Face repository id to load the model from.

    Returns:
        The instantiated DiffusionPipeline.
    """
    # 1. Try to find a single .safetensors file in the repo
    try:
        print(f"Scanning repository {repo_id} for checkpoint files...")
        files = list_repo_files(repo_id)
        # Filter for .safetensors, skipping standalone VAE weight files.
        safetensors_files = [f for f in files if f.endswith(".safetensors") and "vae" not in f]

        if safetensors_files:
            # Pick the first candidate (list_repo_files returns names only,
            # so there is no size information to rank by).
            checkpoint_file = safetensors_files[0]
            print(f"✅ Detected single-file checkpoint: {checkpoint_file}")

            # FIX: from_single_file() expects a path/URL to the checkpoint
            # itself and has no `filename` keyword — passing the bare repo id
            # made this branch always fail. Download the file first, then load.
            from huggingface_hub import hf_hub_download
            local_ckpt = hf_hub_download(repo_id, checkpoint_file)

            pipe = DiffusionPipeline.from_single_file(
                local_ckpt,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=False,
                use_safetensors=True
            )
            return pipe
    except Exception as e:
        print(f"⚠️ Single-file detection failed: {e}")

    # 2. Fallback to standard Diffusers folder structure
    print("🔄 Falling back to standard folder loading...")
    pipe = DiffusionPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
    )
    return pipe
 
 
50
 
51
# Load the model once at import time so a single pipeline is shared by all requests.
pipe = load_pipeline(REPO_ID)

# Move to GPU globally (ZeroGPU handles swapping)
try:
    pipe.to("cuda")
    print("✅ Model moved to CUDA")
except Exception as e:
    # NOTE(review): on CPU-only build machines .to("cuda") raises; the app keeps
    # running and generate_image() re-attempts the move inside the GPU context.
    print(f"⚠️ Could not move to CUDA immediately (Normal for some build environments): {e}")
60
 
61
# ======== GENERATION FUNCTION WITH @spaces.GPU ========
@spaces.GPU
def generate_image(prompt, height, width, num_inference_steps, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    """Generate an image from the given prompt using ZeroGPU."""
    # Draw a fresh seed when the user asked for randomization.
    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()

    rng = torch.Generator("cuda").manual_seed(int(seed))

    # Defensive re-check: make sure the pipeline is on the GPU before sampling.
    if pipe.device.type != "cuda":
        pipe.to("cuda")

    result = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=0.0,  # Turbo models use 0.0
        generator=rng,
    )

    return result.images[0], seed
85
 
86
# Example prompts (one inner list per example; each feeds the single prompt input).
examples = [
    ["Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp, bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda, blurred colorful distant lights."],
    ["A majestic dragon soaring through clouds at sunset, scales shimmering with iridescent colors, detailed fantasy art style"],
    ["Cozy coffee shop interior, warm lighting, rain on windows, plants on shelves, vintage aesthetic, photorealistic"],
    ["Astronaut riding a horse on Mars, cinematic lighting, sci-fi concept art, highly detailed"],
    ["Portrait of a wise old wizard with a long white beard, holding a glowing crystal staff, magical forest background"],
]

# Custom theme
custom_theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="amber",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
).set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)

# Build the Gradio interface. Component creation order defines the on-screen layout.
with gr.Blocks(fill_height=True, theme=custom_theme) as demo:
    gr.Markdown(
        """
        # 🎨 Z-Image-Turbo-NSFW
        **Ultra-fast AI image generation** • Powered by ZeroGPU
        """,
        elem_classes="header-text"
    )

    with gr.Row(equal_height=False):
        # Left column: prompt, settings, and the generate button.
        with gr.Column(scale=1, min_width=320):
            prompt = gr.Textbox(
                label="✨ Your Prompt",
                placeholder="Describe the image...",
                lines=5,
                autofocus=True,
            )

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    height = gr.Slider(512, 2048, value=1024, step=64, label="Height")
                    width = gr.Slider(512, 2048, value=1024, step=64, label="Width")

                num_inference_steps = gr.Slider(1, 20, value=9, step=1, label="Steps")

                with gr.Row():
                    randomize_seed = gr.Checkbox(label="🎲 Random Seed", value=True)
                    seed = gr.Number(label="Seed", value=42, visible=False)

                # Reveal the manual seed field only when random seeding is turned off.
                randomize_seed.change(
                    lambda x: gr.Number(visible=not x),
                    inputs=randomize_seed,
                    outputs=seed
                )

            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            gr.Examples(examples, inputs=prompt, label="Examples")

        # Right column: generated image and the seed that produced it.
        with gr.Column(scale=1, min_width=320):
            output_image = gr.Image(label="Result", type="pil", height=600)
            used_seed = gr.Number(label="Seed Used", interactive=False)

    # Event wiring: button click and Enter-in-textbox trigger the same handler.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed],
        outputs=[output_image, used_seed],
    )
    prompt.submit(
        fn=generate_image,
        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed],
        outputs=[output_image, used_seed],
    )
161
 
162
if __name__ == "__main__":
    # NOTE(review): unlike the previous revision there is no explicit demo.queue();
    # presumably the installed Gradio version queues by default — confirm, since
    # @spaces.GPU and progress tracking rely on queuing.
    demo.launch()