"""FLUX.2 Klein 4B - Free CPU Space with dynamic LoRA search from HuggingFace Hub"""
import os, time, gc, shutil
from pathlib import Path
from PIL import Image
import requests as req
# ---------------------------------------------------------------------------
# Thread config (cgroup-aware)
# ---------------------------------------------------------------------------
def get_cpu_count() -> int:
    """Return the usable CPU count, honoring cgroup v2/v1 quotas.

    Falls back to ``os.cpu_count()`` (minimum 1) when no container limit applies.
    """
    # cgroup v2: "cpu.max" holds "<quota> <period>"; quota may be the literal "max".
    try:
        with open("/sys/fs/cgroup/cpu.max") as fh:
            quota, period = fh.read().strip().split()
        if quota != "max":
            return max(1, int(quota) // int(period))
    except Exception:
        pass
    # cgroup v1: quota and period live in separate files; quota <= 0 means unlimited.
    try:
        with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") as fh:
            quota_us = int(fh.read().strip())
        with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us") as fh:
            period_us = int(fh.read().strip())
        if quota_us > 0:
            return max(1, quota_us // period_us)
    except Exception:
        pass
    return max(1, os.cpu_count() or 2)
# Resolve the thread budget once and propagate it to the common BLAS backends.
# setdefault() so explicitly user-provided values take precedence.
N_THREADS = get_cpu_count()
for _env_var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ.setdefault(_env_var, str(N_THREADS))
print(f"[init] CPU threads: {N_THREADS}")
# ---------------------------------------------------------------------------
# Model resolution
# ---------------------------------------------------------------------------
# Root of the local HuggingFace hub cache (HF_HOME overrides the default).
HF_CACHE = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface" / "hub"))


def find_model(filename: str) -> str:
    """Locate *filename*, searching the CWD, ./models, then the HF hub cache.

    Args:
        filename: bare model filename (no directory component).

    Returns:
        Path to the first match, as a string.

    Raises:
        FileNotFoundError: if the file exists in none of the search locations.
    """
    for d in [Path("."), Path("models")]:
        if (d / filename).exists():
            return str(d / filename)
    # The hub cache nests files under repo/snapshot directories; first hit wins.
    for p in HF_CACHE.rglob(filename):
        return str(p)
    # Fix: the previous message was a placeholder-less f-string ("Not found:
    # (unknown)") that dropped the filename; include it so errors are actionable.
    raise FileNotFoundError(f"Not found: {filename}")
# ---------------------------------------------------------------------------
# Load base models
# ---------------------------------------------------------------------------
from huggingface_hub import hf_hub_download, list_repo_files
from stable_diffusion_cpp import StableDiffusion
DIFFUSION_FILE = "flux-2-klein-4b-Q4_K_M.gguf"
LLM_FILE = "qwen3-4b-abl-q4_0.gguf"
VAE_FILE = "flux2-vae.safetensors"

print("[init] Locating models...")
diffusion_path = find_model(DIFFUSION_FILE)
vae_path = find_model(VAE_FILE)
# The text encoder may not ship with the Space image; pull it from the Hub on demand.
try:
    llm_path = find_model(LLM_FILE)
except FileNotFoundError:
    print("[init] Downloading uncensored text encoder...")
    llm_path = hf_hub_download(
        repo_id="WeReCooking/flux2-klein-4B-uncensored-text-encoder",
        filename=LLM_FILE,
    )
for _label, _path in (("Diffusion", diffusion_path), ("LLM", llm_path), ("VAE", vae_path)):
    print(f"[init] {_label}: {_path}")
# ---------------------------------------------------------------------------
# LoRA management
# ---------------------------------------------------------------------------
# Downloaded LoRA files are copied here so sd.cpp's lora_model_dir can see them.
LORA_DIR = "/tmp/loras"
os.makedirs(LORA_DIR, exist_ok=True)
# Maps "repo_id/filename" labels -> sanitized name (the on-disk stem inside LORA_DIR).
DOWNLOADED_LORAS: dict[str, str] = {}
def fetch_all_loras(query: str = "") -> list[str]:
    """Query the HF Hub model-search API for Klein-4B LoRA repos.

    Results come back sorted by downloads (descending, max 50); any failure
    is logged and yields an empty list.
    """
    search_term = f"klein 4b {query}".strip()
    try:
        resp = req.get(
            "https://huggingface.co/api/models",
            params={
                "search": search_term, "filter": "lora",
                "sort": "downloads", "direction": "-1", "limit": 50,
            },
            timeout=10,
        )
        resp.raise_for_status()
        # Keep entries tagged as LoRAs, or whose repo id merely looks like one.
        return [
            entry.get("id", "")
            for entry in resp.json()
            if "lora" in entry.get("tags", []) or "lora" in entry.get("id", "").lower()
        ]
    except Exception as e:
        print(f"[lora] Search error: {e}")
        return []
def download_lora(repo_id: str) -> tuple[str, str]:
    """Fetch a LoRA ``.safetensors`` file from *repo_id* into LORA_DIR.

    Returns:
        ``(label, status_message)`` where label is ``"repo_id/filename"`` on
        success and ``""`` on failure (including the not-selected guard).
    """
    # Guard against the empty selection / "(placeholder)" dropdown entries.
    if not repo_id or repo_id.startswith("("):
        return "", "Select a LoRA first"
    try:
        candidates = [f for f in list_repo_files(repo_id) if f.endswith(".safetensors")]
        if not candidates:
            return "", f"No .safetensors found in {repo_id}"
        # Prefer a file whose name hints it is the adapter itself.
        target = next(
            (f for f in candidates if "lora" in f.lower() or "adapter" in f.lower()),
            candidates[0],
        )
        label = f"{repo_id}/{target}"
        # sd.cpp resolves <lora:name:...> tags against filenames, so sanitize.
        safe_name = label.replace("/", "_").replace("-", "_").replace(".", "_")
        safe_name = safe_name.rsplit("_safetensors", 1)[0]
        dest = os.path.join(LORA_DIR, f"{safe_name}.safetensors")
        if label in DOWNLOADED_LORAS:
            cached_mb = os.path.getsize(dest) / 1024**2
            return label, f"Already cached ({cached_mb:.0f} MB)"
        print(f"[lora] Downloading {repo_id}/{target}...")
        shutil.copy2(hf_hub_download(repo_id=repo_id, filename=target), dest)
        size_mb = os.path.getsize(dest) / 1024**2
        DOWNLOADED_LORAS[label] = safe_name
        print(f"[lora] Downloaded: {label} ({size_mb:.0f} MB)")
        return label, f"Downloaded: {label} ({size_mb:.0f} MB)"
    except Exception as e:
        return "", f"Failed: {e}"
# ---------------------------------------------------------------------------
# Engine
# ---------------------------------------------------------------------------
# Singleton holder: "instance" is the loaded StableDiffusion engine (or None),
# "lora_state" is the frozenset of LoRA filenames the instance was built with.
SD_ENGINE = {"instance": None, "lora_state": None}
def _reload_engine():
    """(Re)build the StableDiffusion engine when the set of LoRA files changes.

    A no-op if an instance exists and LORA_DIR's contents are unchanged.
    """
    current_loras = set(os.listdir(LORA_DIR)) if os.path.exists(LORA_DIR) else set()
    state_key = frozenset(current_loras)
    if SD_ENGINE["instance"] is not None and SD_ENGINE["lora_state"] == state_key:
        return  # up to date; keep the expensive instance alive
    print(f"[engine] Loading (loras: {len(current_loras)})...")
    started = time.time()
    init_args = {
        "diffusion_model_path": diffusion_path,
        "llm_path": llm_path,
        "vae_path": vae_path,
        "diffusion_flash_attn": True,
        "n_threads": N_THREADS,
        "verbose": True,
    }
    # Only point sd.cpp at the LoRA dir when there is something to load.
    if current_loras:
        init_args["lora_model_dir"] = LORA_DIR
    SD_ENGINE["instance"] = StableDiffusion(**init_args)
    SD_ENGINE["lora_state"] = state_key
    print(f"[engine] Loaded in {time.time()-started:.1f}s")
def get_engine():
    """Return the shared engine, lazily constructing it on first call."""
    engine = SD_ENGINE["instance"]
    if engine is None:
        _reload_engine()
        engine = SD_ENGINE["instance"]
    return engine
# Warm the engine at startup so the first request does not pay the load cost.
_reload_engine()
print("[init] Fetching Klein 4B LoRA catalog...")
# Seed the LoRA dropdown with the most-downloaded Klein 4B LoRAs on the Hub.
INITIAL_LORAS = fetch_all_loras("")
print(f"[init] Found {len(INITIAL_LORAS)} LoRAs")
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
# Resolution presets offered in the UI dropdown (all "WxH" strings).
RESOLUTIONS = ["512x512", "768x768", "1024x1024", "1024x768", "768x1024", "1024x576", "576x1024"]


def parse_res(s):
    """Split a "WxH" resolution string into an ``(int, int)`` tuple."""
    width, height = s.split("x")
    return int(width), int(height)
def generate(prompt, ref_image, resolution, steps, seed, lora_strength, active_loras, progress=None):
    """Run one text-to-image (or image-edit, if a reference is given) pass.

    Returns:
        ``(image_or_None, status_string)`` — never raises; failures come back
        as ``(None, "Error: ...")``.
    """
    try:
        gc.collect()  # reclaim the previous run's memory before a big allocation
        engine = get_engine()
        width, height = parse_res(resolution)
        steps = int(steps)
        seed = int(seed) if int(seed) >= 0 else -1  # negative -> random seed
        # Prepend one <lora:name:strength> tag per active LoRA; sd.cpp parses them.
        full_prompt = prompt
        applied = []
        for label in (active_loras or []):
            lora_name = DOWNLOADED_LORAS.get(label)
            if lora_name:
                full_prompt = f'<lora:{lora_name}:{lora_strength:.2f}> {full_prompt}'
                applied.append(label.split("/")[-1])
        editing = ref_image is not None
        mode = "edit" if editing else "gen"
        print(f"[{mode}] {width}x{height} steps={steps} seed={seed} loras={applied}")
        started = time.time()
        call_kwargs = dict(
            prompt=full_prompt, width=width, height=height,
            sample_steps=steps, cfg_scale=1.0, seed=seed,
        )
        if editing:
            call_kwargs["ref_images"] = [ref_image]
        images = engine.generate_image(**call_kwargs)
        elapsed = time.time() - started
        lora_info = f" +{len(applied)} LoRA(s)" if applied else ""
        edit_info = " [edit]" if editing else ""
        status = f"{elapsed:.1f}s | {width}x{height}, {steps} steps, seed {seed}{lora_info}{edit_info}"
        print(f"[{mode}] {status}")
        return (images[0] if images else None), status
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {e}"
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
import gradio as gr
# Build the Blocks UI: controls on the left, output on the right.
with gr.Blocks(theme="NoCrypt/miku", title="FLUX.2 Klein 4B CPU") as demo:
    gr.Markdown(
        "# FLUX.2 Klein 4B / Free CPU\n"
        "Type a prompt to generate. Upload a reference image to edit it instead. "
        "Expect **15-30 min** per image at 512x512 on free CPU."
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe what to generate or edit...")
            ref_image = gr.Image(label="Reference Image (optional, for editing)", type="pil")
            resolution = gr.Dropdown(choices=RESOLUTIONS, value="512x512", label="Resolution")
            with gr.Row():
                steps = gr.Slider(2, 8, value=4, step=1, label="Steps", scale=1)
                seed = gr.Number(value=-1, label="Seed", precision=0, scale=1)
                lora_strength = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="LoRA str", scale=1)
            with gr.Accordion("LoRA (search Klein 4B LoRAs on HuggingFace)", open=False):
                lora_search = gr.Dropdown(
                    choices=INITIAL_LORAS, value=None,
                    label="Search LoRA repos (type to filter, select to download)",
                    filterable=True, allow_custom_value=True, interactive=True,
                )
                lora_status = gr.Textbox(label="Status", interactive=False, value="No LoRA active")
                active_loras = gr.Dropdown(
                    choices=[], value=[], multiselect=True, interactive=True,
                    label="Active LoRAs (click X to remove)",
                )
            gen_btn = gr.Button("Generate / Edit", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output", type="pil")
            status_text = gr.Textbox(label="Status", interactive=False)

    def on_search_type(query):
        """Live-search the Hub as the user types; fall back to the initial catalog."""
        if not query or query in INITIAL_LORAS:
            return gr.update(choices=INITIAL_LORAS)
        hits = fetch_all_loras(query)
        return gr.update(choices=hits if hits else INITIAL_LORAS)

    def on_lora_select(repo_id, current_active):
        """Download the chosen LoRA, add it to the active set, clear the search box."""
        if not repo_id or repo_id.startswith("("):
            return current_active or [], "Select a LoRA", gr.update()
        label, status_msg = download_lora(repo_id)
        if not label:
            return current_active or [], status_msg, gr.update()
        _reload_engine()  # rebuild so sd.cpp picks up the newly copied file
        selection = list(current_active) if current_active else []
        if label not in selection:
            selection.append(label)
        return (
            gr.update(choices=list(DOWNLOADED_LORAS.keys()), value=selection),
            status_msg,
            gr.update(value=None),
        )

    lora_search.input(fn=on_search_type, inputs=[lora_search], outputs=[lora_search])
    lora_search.select(fn=on_lora_select, inputs=[lora_search, active_loras], outputs=[active_loras, lora_status, lora_search])
    gen_btn.click(fn=generate, inputs=[prompt, ref_image, resolution, steps, seed, lora_strength, active_loras], outputs=[output_image, status_text])
    gr.Markdown("---\nsd.cpp Q4_K_M | Uncensored encoder | "
                "[BFL](https://bfl.ai/models/flux-2-klein) | [sd.cpp](https://github.com/leejet/stable-diffusion.cpp) | "
                "[Browse LoRAs](https://huggingface.co/models?search=klein+4b&filter=lora)")

demo.queue().launch(ssr_mode=False, show_error=True)