prithivMLmods commited on
Commit
d44bc10
·
verified ·
1 Parent(s): 4a15cf2

update app

Browse files
Files changed (1) hide show
  1. app.py +130 -242
app.py CHANGED
@@ -7,7 +7,6 @@ import base64
7
  import random
8
  import zipfile
9
  import threading
10
- import concurrent.futures
11
  from pathlib import Path
12
  from typing import List, Optional
13
 
@@ -46,13 +45,6 @@ else:
46
  DEVICE_LABEL = str(device).lower()
47
 
48
  # --- Model Loading ---
49
- print("Loading 4B Distilled model (Standard VAE)...")
50
- pipe_standard = Flux2KleinPipeline.from_pretrained(
51
- "black-forest-labs/FLUX.2-klein-4B",
52
- torch_dtype=dtype,
53
- ).to(device)
54
- pipe_standard.enable_model_cpu_offload()
55
-
56
  print("Loading Small Decoder VAE...")
57
  vae_small = AutoencoderKLFlux2.from_pretrained(
58
  "black-forest-labs/FLUX.2-small-decoder",
@@ -60,15 +52,14 @@ vae_small = AutoencoderKLFlux2.from_pretrained(
60
  ).to(device)
61
 
62
  print("Loading 4B Distilled model (Small Decoder VAE)...")
63
- pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
64
  "black-forest-labs/FLUX.2-klein-4B",
65
  vae=vae_small,
66
  torch_dtype=dtype,
67
  ).to(device)
68
- pipe_small_decoder.enable_model_cpu_offload()
69
 
70
- pipe_lock_standard = threading.Lock()
71
- pipe_lock_small = threading.Lock()
72
 
73
  # --- Utility Functions ---
74
  def calc_dimensions(pil_img: Image.Image):
@@ -89,7 +80,7 @@ def calc_dimensions(pil_img: Image.Image):
89
  def parse_and_resize_images(image_paths: List[str], width: int, height: int):
90
  if not image_paths:
91
  return None
92
-
93
  resized = []
94
  for path in image_paths:
95
  try:
@@ -97,11 +88,11 @@ def parse_and_resize_images(image_paths: List[str], width: int, height: int):
97
  resized.append(img.resize((width, height), Image.LANCZOS))
98
  except Exception as e:
99
  print(f"Skipping invalid image: {e}")
100
-
101
  return resized if resized else None
102
 
103
- def run_pipeline(pipe, lock, kwargs, seed):
104
- with lock:
105
  gen = torch.Generator(device="cpu").manual_seed(seed)
106
  result = pipe(**kwargs, generator=gen).images[0]
107
  return result
@@ -146,7 +137,7 @@ def infer(
146
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
147
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
148
 
149
- shared_kwargs = dict(
150
  prompt=prompt,
151
  height=height,
152
  width=width,
@@ -154,25 +145,15 @@ def infer(
154
  guidance_scale=guidance_scale,
155
  )
156
  if image_list is not None:
157
- shared_kwargs["image"] = image_list
158
-
159
- with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
160
- future_std = executor.submit(run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed)
161
- future_small = executor.submit(run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed)
162
-
163
- concurrent.futures.wait(
164
- [future_std, future_small],
165
- return_when=concurrent.futures.ALL_COMPLETED,
166
- )
167
 
168
- out_standard = future_std.result()
169
- out_small = future_small.result()
170
 
171
  gc.collect()
172
  if torch.cuda.is_available():
173
  torch.cuda.empty_cache()
174
 
175
- return out_standard, out_small, seed
176
 
177
 
178
  # --- FastAPI Endpoints ---
@@ -214,33 +195,8 @@ async def download_file(filename: str):
214
  return JSONResponse({"error": "File not found"}, status_code=404)
215
  return FileResponse(path, filename=filename, media_type="image/png")
216
 
217
- @app.get("/api/download-zip")
218
- async def download_zip(std: str, small: str):
219
- """Packages both generated images into a single ZIP file and streams it."""
220
- std_name = Path(std).name
221
- small_name = Path(small).name
222
-
223
- std_path = OUTPUT_DIR / std_name
224
- small_path = OUTPUT_DIR / small_name
225
-
226
- if not std_path.exists() or not small_path.exists():
227
- return JSONResponse({"error": "Generated files not found"}, status_code=404)
228
-
229
- memory_file = io.BytesIO()
230
- with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
231
- zf.write(std_path, arcname=f"Standard_Decoder_{std_name}")
232
- zf.write(small_path, arcname=f"Small_Decoder_{small_name}")
233
-
234
- memory_file.seek(0)
235
-
236
- return StreamingResponse(
237
- memory_file,
238
- media_type="application/zip",
239
- headers={"Content-Disposition": f"attachment; filename=Flux2_Comparison_{uuid.uuid4().hex[:6]}.zip"}
240
- )
241
-
242
- @app.post("/api/compare")
243
- async def compare_images(
244
  prompt: str = Form(...),
245
  seed: str = Form("0"),
246
  randomize_seed: str = Form("true"),
@@ -264,7 +220,7 @@ async def compare_images(
264
  temp_paths.append(str(temp_path))
265
  image_paths.append(str(temp_path))
266
 
267
- result_std, result_small, used_seed = infer(
268
  prompt=prompt,
269
  image_paths=image_paths,
270
  seed=int(seed),
@@ -275,16 +231,13 @@ async def compare_images(
275
  guidance_scale=float(guidance),
276
  )
277
 
278
- std_filename = save_image(result_std, prefix="std")
279
- small_filename = save_image(result_small, prefix="small")
280
 
281
  return JSONResponse({
282
  "success": True,
283
  "seed": used_seed,
284
- "std_url": f"/download/{std_filename}",
285
- "small_url": f"/download/{small_filename}",
286
- "std_filename": std_filename,
287
- "small_filename": small_filename,
288
  "device": DEVICE_LABEL,
289
  })
290
 
@@ -301,13 +254,12 @@ async def homepage(request: Request):
301
  examples = get_example_items()
302
  examples_json = json.dumps(examples)
303
 
304
- return f"""
305
- <!DOCTYPE html>
306
  <html lang="en">
307
  <head>
308
  <meta charset="UTF-8" />
309
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
310
- <title>Flux.2-4B-Decoder-Comparator</title>
311
  <link href="https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&display=swap" rel="stylesheet">
312
  <style>
313
  :root {{
@@ -366,13 +318,13 @@ async def homepage(request: Request):
366
  margin: 0;
367
  }}
368
 
369
- /* FIXED LAYOUT GRID */
370
  .layout {{
371
  display: grid;
372
  grid-template-columns: 420px 1fr;
373
  gap: 24px;
374
  align-items: stretch;
375
- height: 650px;
376
  }}
377
 
378
  .panel {{
@@ -446,7 +398,7 @@ async def homepage(request: Request):
446
  }}
447
  .upload-zone input[type="file"] {{ display: none; }}
448
  .upload-text {{ pointer-events: none; color: var(--ub-muted); }}
449
-
450
  .preview-grid {{
451
  display: none;
452
  grid-template-columns: repeat(auto-fill, minmax(70px, 1fr));
@@ -466,7 +418,7 @@ async def homepage(request: Request):
466
  display: flex; align-items: center; justify-content: center;
467
  cursor: pointer; font-size: 12px;
468
  }}
469
-
470
  .add-more-btn {{
471
  display: flex; align-items: center; justify-content: center;
472
  border: 2px dashed var(--ub-muted); border-radius: 4px;
@@ -533,52 +485,29 @@ async def homepage(request: Request):
533
  }}
534
  .action-icon:hover {{ color: var(--ub-orange); }}
535
 
536
- /* SLIDER CONTAINER */
537
- .panel-body-slider {{
538
  flex: 1; display: flex; flex-direction: column;
539
  padding: 0; position: relative;
540
  }}
541
- .slider-stage {{
542
  position: absolute; top: 0; left: 0; right: 0; bottom: 0;
543
  background: #111; overflow: hidden; display: flex;
544
  align-items: center; justify-content: center;
545
  }}
546
- .slider-empty {{ color: var(--ub-muted); text-align: center; z-index: 1; }}
547
-
548
- .slider-img {{
549
- position: absolute; top: 0; left: 0; width: 100%; height: 100%;
550
- object-fit: contain; display: none; user-select: none; -webkit-user-drag: none;
551
- }}
552
- #imgSmall {{ clip-path: inset(0 50% 0 0); }}
553
-
554
- .slider-handle {{
555
- position: absolute; left: 50%; top: 0; bottom: 0;
556
- width: 4px; background: var(--ub-orange); cursor: ew-resize; display: none; z-index: 10;
557
- }}
558
- .slider-handle::after {{
559
- content: '◀ ▶'; position: absolute; top: 50%; left: 50%;
560
- transform: translate(-50%, -50%); width: 40px; height: 30px;
561
- background: var(--ub-orange); color: white; border-radius: 15px;
562
- display: flex; align-items: center; justify-content: center;
563
- font-size: 10px; font-weight: bold; box-shadow: 0 2px 6px rgba(0,0,0,0.5);
564
- }}
565
 
566
- .slider-labels {{
567
- position: absolute; top: 15px; left: 15px; right: 15px;
568
- display: none; justify-content: space-between;
569
- pointer-events: none; z-index: 5;
570
- }}
571
- .badge {{
572
- background: rgba(0,0,0,0.6); color: white; padding: 6px 12px;
573
- border-radius: 20px; font-size: 13px; backdrop-filter: blur(4px);
574
  }}
575
 
576
- /* UPDATED LOADER ANIMATION (Minimalist Single Circle) */
577
  .loader {{
578
- position: absolute; inset: 0;
579
- background: rgba(20, 0, 10, 0.7); /* dark aubergine tint */
580
  backdrop-filter: blur(6px);
581
- display: none; flex-direction: column;
582
  align-items: center; justify-content: center; z-index: 20;
583
  }}
584
  .spinner-single {{
@@ -600,8 +529,17 @@ async def homepage(request: Request):
600
  0%, 100% {{ opacity: 1; }}
601
  50% {{ opacity: 0.5; }}
602
  }}
603
- @keyframes spin {{
604
- to {{ transform: rotate(360deg); }}
 
 
 
 
 
 
 
 
 
605
  }}
606
 
607
  /* Examples */
@@ -621,22 +559,23 @@ async def homepage(request: Request):
621
 
622
  @media (max-width: 900px) {{
623
  .layout {{ grid-template-columns: 1fr; height: auto; }}
624
- .panel-body-slider {{ height: 450px; flex: none; }}
625
- .slider-stage {{ position: relative; height: 100%; }}
626
  }}
627
  </style>
628
  </head>
629
  <body>
630
 
631
- <div class="topbar">Flux.2-4B VAE Decoder Comparator</div>
632
 
633
  <div class="container">
634
  <div class="header-text">
635
- <h1>Standard vs. Small Decoder</h1>
636
- <p>Upload an image, enter a prompt, and use the slider to compare outputs in real-time.</p>
637
  </div>
638
 
639
  <div class="layout">
 
640
  <div class="panel">
641
  <div class="panel-header">Settings</div>
642
  <div class="panel-body-scroll">
@@ -657,7 +596,7 @@ async def homepage(request: Request):
657
  <button class="advanced-toggle" id="advToggle">
658
  <span>Advanced Settings</span> <span class="advanced-icon" id="advIcon">+</span>
659
  </button>
660
-
661
  <div class="advanced-body" id="advBody">
662
  <div class="grid-2">
663
  <div class="form-group">
@@ -695,14 +634,15 @@ async def homepage(request: Request):
695
  </div>
696
  </div>
697
 
698
- <button class="btn btn-primary" id="runBtn">Run Comparison</button>
699
  </div>
700
  </div>
701
 
 
702
  <div class="panel">
703
  <div class="panel-header">
704
- <span>Comparison View</span>
705
- <button id="downloadZipBtn" class="action-icon" title="Download Both Images (ZIP)">
706
  <svg width="22" height="22" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" viewBox="0 0 24 24">
707
  <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"></path>
708
  <polyline points="7 10 12 15 17 10"></polyline>
@@ -710,34 +650,29 @@ async def homepage(request: Request):
710
  </svg>
711
  </button>
712
  </div>
713
- <div class="panel-body-slider">
714
- <div class="slider-stage" id="sliderStage">
715
- <div class="slider-empty" id="sliderEmpty">
716
  <svg width="48" height="48" fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" style="margin-bottom:10px; opacity:0.5;">
717
  <path d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z"></path>
718
  </svg>
719
- <div>Results will appear here</div>
720
  </div>
721
 
722
- <img id="imgStd" class="slider-img" alt="Standard Decoder" />
723
- <img id="imgSmall" class="slider-img" alt="Small Decoder" />
724
-
725
- <div class="slider-labels" id="sliderLabels">
726
- <div class="badge">Standard Decoder</div>
727
- <div class="badge">Small Decoder</div>
728
- </div>
729
 
730
- <div class="slider-handle" id="sliderHandle"></div>
731
 
732
  <div class="loader" id="loader">
733
  <div class="spinner-single"></div>
734
- <div class="loader-text">Running both decoders...</div>
735
  </div>
736
  </div>
737
  </div>
738
  </div>
739
  </div>
740
 
 
741
  <div class="examples-section">
742
  <h3>Examples</h3>
743
  <div class="examples-grid" id="examplesGrid"></div>
@@ -747,39 +682,31 @@ async def homepage(request: Request):
747
  <script>
748
  const examples = {examples_json};
749
  let filesState = [];
750
- let currentStdFilename = "";
751
- let currentSmallFilename = "";
752
-
753
- // UI Elements
754
- const dropZone = document.getElementById('dropZone');
755
- const fileInput = document.getElementById('fileInput');
756
- const previewGrid = document.getElementById('previewGrid');
757
- const uploadText = document.getElementById('uploadText');
758
- const promptInput = document.getElementById('promptInput');
759
- const runBtn = document.getElementById('runBtn');
760
- const downloadZipBtn = document.getElementById('downloadZipBtn');
761
-
762
- // Status Log
763
- const statusLog = document.getElementById('statusLog');
 
764
 
765
  function logMsg(msg, styleClass="") {{
766
  const div = document.createElement('div');
767
  const timeStr = new Date().toLocaleTimeString('en-US', {{hour12:false}});
768
  div.innerHTML = `<span class="log-time">[${{timeStr}}]</span><span class="${{styleClass}}">${{msg}}</span>`;
769
  statusLog.appendChild(div);
770
- statusLog.scrollTop = statusLog.scrollHeight; // auto-scroll to bottom
771
  }}
772
 
773
- // Slider Elements
774
- const sliderStage = document.getElementById('sliderStage');
775
- const imgStd = document.getElementById('imgStd');
776
- const imgSmall = document.getElementById('imgSmall');
777
- const sliderHandle = document.getElementById('sliderHandle');
778
- const sliderLabels = document.getElementById('sliderLabels');
779
- const sliderEmpty = document.getElementById('sliderEmpty');
780
- const loader = document.getElementById('loader');
781
-
782
- // Advanced Toggle logic (+ / -)
783
  document.getElementById('advToggle').onclick = function() {{
784
  const body = document.getElementById('advBody');
785
  body.classList.toggle('open');
@@ -789,10 +716,10 @@ async def homepage(request: Request):
789
  // --- File Upload Logic ---
790
  function renderPreviews() {{
791
  previewGrid.innerHTML = '';
792
- if(filesState.length > 0) {{
793
  uploadText.style.display = 'none';
794
  previewGrid.style.display = 'grid';
795
-
796
  filesState.forEach((f, i) => {{
797
  const div = document.createElement('div');
798
  div.className = 'thumb';
@@ -805,27 +732,25 @@ async def homepage(request: Request):
805
  div.appendChild(img); div.appendChild(btn);
806
  previewGrid.appendChild(div);
807
  }});
808
-
809
- // Append dynamic + button
810
  const addBtn = document.createElement('div');
811
  addBtn.className = 'add-more-btn';
812
  addBtn.innerHTML = '+';
813
  addBtn.onclick = (e) => {{ e.stopPropagation(); fileInput.click(); }};
814
  previewGrid.appendChild(addBtn);
815
-
816
  }} else {{
817
  uploadText.style.display = 'block';
818
  previewGrid.style.display = 'none';
819
  }}
820
  }}
821
 
822
- dropZone.onclick = (e) => {{ if(e.target === dropZone || e.target === uploadText) fileInput.click(); }};
823
  fileInput.onchange = (e) => {{ filesState.push(...Array.from(e.target.files)); renderPreviews(); fileInput.value=''; }};
824
  dropZone.ondragover = (e) => {{ e.preventDefault(); dropZone.classList.add('dragover'); }};
825
  dropZone.ondragleave = () => dropZone.classList.remove('dragover');
826
  dropZone.ondrop = (e) => {{
827
  e.preventDefault(); dropZone.classList.remove('dragover');
828
- if(e.dataTransfer.files.length) {{ filesState.push(...Array.from(e.dataTransfer.files)); renderPreviews(); }}
829
  }};
830
 
831
  // --- Examples Logic ---
@@ -834,26 +759,23 @@ async def homepage(request: Request):
834
  renderPreviews();
835
  promptInput.value = text;
836
  logMsg("Loading example: " + text, "log-info");
837
-
838
  try {{
839
- for(let i=0; i<urls.length; i++) {{
840
- const res = await fetch(urls[i]);
841
- const blob = await res.blob();
842
- const filename = urls[i].split('/').pop();
843
- filesState.push(new File([blob], filename, {{type: blob.type}}));
844
  }}
845
  renderPreviews();
846
-
847
  window.scrollTo({{top: 0, behavior: 'smooth'}});
848
-
849
  setTimeout(() => {{
850
- logMsg("Example loaded. Starting comparison...", "log-info");
851
- runBtn.click();
852
  }}, 500);
853
-
854
- }} catch (e) {{
855
  logMsg("Failed to load example images.", "log-error");
856
- alert('Failed to load example image.');
857
  }}
858
  }}
859
 
@@ -861,66 +783,39 @@ async def homepage(request: Request):
861
  examples.forEach(ex => {{
862
  const card = document.createElement('div');
863
  card.className = 'ex-card';
864
-
865
  let imgHTML = '';
866
- if(ex.urls.length > 1) {{
867
  imgHTML = `
868
  <div class="ex-card-img-wrap">
869
- <img src="${{ex.urls[0]}}" style="width:50%; border-right:1px solid #000;">
870
- <img src="${{ex.urls[1]}}" style="width:50%;">
871
- </div>
872
- `;
873
  }} else {{
874
  imgHTML = `<div class="ex-card-img-wrap"><img src="${{ex.urls[0]}}" style="width:100%;"></div>`;
875
  }}
876
-
877
  card.innerHTML = `${{imgHTML}}<p>${{ex.prompt}}</p>`;
878
  card.onclick = () => loadExample(ex.urls, ex.prompt);
879
  exGrid.appendChild(card);
880
  }});
881
 
882
- // --- Image Slider Logic ---
883
- let isDragging = false;
884
-
885
- function updateSlider(clientX) {{
886
- const rect = sliderStage.getBoundingClientRect();
887
- let pos = Math.max(0, Math.min(clientX - rect.left, rect.width));
888
- let percent = (pos / rect.width) * 100;
889
-
890
- sliderHandle.style.left = percent + '%';
891
- imgSmall.style.clipPath = `inset(0 ${{100 - percent}}% 0 0)`;
892
- }}
893
-
894
- sliderHandle.addEventListener('mousedown', () => isDragging = true);
895
- window.addEventListener('mouseup', () => isDragging = false);
896
- window.addEventListener('mousemove', (e) => {{
897
- if (!isDragging) return;
898
- updateSlider(e.clientX);
899
- }});
900
-
901
- sliderHandle.addEventListener('touchstart', () => isDragging = true);
902
- window.addEventListener('touchend', () => isDragging = false);
903
- window.addEventListener('touchmove', (e) => {{
904
- if (!isDragging) return;
905
- updateSlider(e.touches[0].clientX);
906
- }});
907
-
908
- // --- Download Zip Logic ---
909
- downloadZipBtn.onclick = () => {{
910
- if(!currentStdFilename || !currentSmallFilename) return;
911
- logMsg("Initiating ZIP download...", "log-info");
912
- window.location.href = `/api/download-zip?std=${{currentStdFilename}}&small=${{currentSmallFilename}}`;
913
  }};
914
 
915
  // --- Form Submission ---
916
  runBtn.onclick = async () => {{
917
  const prompt = promptInput.value.trim();
918
- if(!prompt) {{
919
  logMsg("Validation failed: Prompt is empty.", "log-error");
920
  return alert("Enter a prompt");
921
  }}
922
 
923
- logMsg("Initializing generation sequence...", "log-info");
924
 
925
  const fd = new FormData();
926
  fd.append('prompt', prompt);
@@ -930,51 +825,44 @@ async def homepage(request: Request):
930
  fd.append('height', document.getElementById('height').value);
931
  fd.append('steps', document.getElementById('steps').value);
932
  fd.append('guidance', document.getElementById('guidance').value);
933
-
934
  filesState.forEach(f => fd.append('images', f));
935
 
936
  loader.style.display = 'flex';
937
  runBtn.disabled = true;
938
- downloadZipBtn.style.display = 'none';
 
939
 
940
- logMsg("Sending request to backend. Running both VAE models...", "log-info");
941
 
942
  try {{
943
- const res = await fetch('/api/compare', {{ method: 'POST', body: fd }});
944
  const data = await res.json();
945
-
946
- if(data.success) {{
947
- logMsg(`Success! Inference completed. Used seed: ${{data.seed}}`, "log-success");
948
-
949
- currentStdFilename = data.std_filename;
950
- currentSmallFilename = data.small_filename;
951
-
952
- imgStd.src = data.std_url;
953
- imgSmall.src = data.small_url;
954
-
955
- imgStd.onload = () => {{
956
- sliderEmpty.style.display = 'none';
957
- imgStd.style.display = 'block';
958
- imgSmall.style.display = 'block';
959
- sliderHandle.style.display = 'block';
960
- sliderLabels.style.display = 'flex';
961
- downloadZipBtn.style.display = 'block'; // Reveal download button
962
-
963
- // Reset slider to center
964
- const rect = sliderStage.getBoundingClientRect();
965
- updateSlider(rect.left + rect.width / 2);
966
  }};
967
  }} else {{
968
- logMsg("Error processing request: " + data.error, "log-error");
969
  alert('Error: ' + data.error);
970
  }}
971
- }} catch(e) {{
972
  logMsg("Network or server connection failed.", "log-error");
973
  alert('Failed to connect to server.');
974
  }} finally {{
975
  loader.style.display = 'none';
976
  runBtn.disabled = false;
977
- logMsg("Sequence finished. Ready for next input.", "");
978
  }}
979
  }};
980
  </script>
 
7
  import random
8
  import zipfile
9
  import threading
 
10
  from pathlib import Path
11
  from typing import List, Optional
12
 
 
45
  DEVICE_LABEL = str(device).lower()
46
 
47
  # --- Model Loading ---
 
 
 
 
 
 
 
48
  print("Loading Small Decoder VAE...")
49
  vae_small = AutoencoderKLFlux2.from_pretrained(
50
  "black-forest-labs/FLUX.2-small-decoder",
 
52
  ).to(device)
53
 
54
  print("Loading 4B Distilled model (Small Decoder VAE)...")
55
+ pipe = Flux2KleinPipeline.from_pretrained(
56
  "black-forest-labs/FLUX.2-klein-4B",
57
  vae=vae_small,
58
  torch_dtype=dtype,
59
  ).to(device)
60
+ pipe.enable_model_cpu_offload()
61
 
62
+ pipe_lock = threading.Lock()
 
63
 
64
  # --- Utility Functions ---
65
  def calc_dimensions(pil_img: Image.Image):
 
80
  def parse_and_resize_images(image_paths: List[str], width: int, height: int):
81
  if not image_paths:
82
  return None
83
+
84
  resized = []
85
  for path in image_paths:
86
  try:
 
88
  resized.append(img.resize((width, height), Image.LANCZOS))
89
  except Exception as e:
90
  print(f"Skipping invalid image: {e}")
91
+
92
  return resized if resized else None
93
 
94
+ def run_pipeline(kwargs, seed):
95
+ with pipe_lock:
96
  gen = torch.Generator(device="cpu").manual_seed(seed)
97
  result = pipe(**kwargs, generator=gen).images[0]
98
  return result
 
137
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
138
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
139
 
140
+ kwargs = dict(
141
  prompt=prompt,
142
  height=height,
143
  width=width,
 
145
  guidance_scale=guidance_scale,
146
  )
147
  if image_list is not None:
148
+ kwargs["image"] = image_list
 
 
 
 
 
 
 
 
 
149
 
150
+ result = run_pipeline(kwargs, seed)
 
151
 
152
  gc.collect()
153
  if torch.cuda.is_available():
154
  torch.cuda.empty_cache()
155
 
156
+ return result, seed
157
 
158
 
159
  # --- FastAPI Endpoints ---
 
195
  return JSONResponse({"error": "File not found"}, status_code=404)
196
  return FileResponse(path, filename=filename, media_type="image/png")
197
 
198
+ @app.post("/api/generate")
199
+ async def generate_image(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  prompt: str = Form(...),
201
  seed: str = Form("0"),
202
  randomize_seed: str = Form("true"),
 
220
  temp_paths.append(str(temp_path))
221
  image_paths.append(str(temp_path))
222
 
223
+ result, used_seed = infer(
224
  prompt=prompt,
225
  image_paths=image_paths,
226
  seed=int(seed),
 
231
  guidance_scale=float(guidance),
232
  )
233
 
234
+ filename = save_image(result, prefix="output")
 
235
 
236
  return JSONResponse({
237
  "success": True,
238
  "seed": used_seed,
239
+ "url": f"/download/{filename}",
240
+ "filename": filename,
 
 
241
  "device": DEVICE_LABEL,
242
  })
243
 
 
254
  examples = get_example_items()
255
  examples_json = json.dumps(examples)
256
 
257
+ return f"""<!DOCTYPE html>
 
258
  <html lang="en">
259
  <head>
260
  <meta charset="UTF-8" />
261
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
262
+ <title>Flux.2 Klein — Small Decoder</title>
263
  <link href="https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&display=swap" rel="stylesheet">
264
  <style>
265
  :root {{
 
318
  margin: 0;
319
  }}
320
 
321
+ /* LAYOUT */
322
  .layout {{
323
  display: grid;
324
  grid-template-columns: 420px 1fr;
325
  gap: 24px;
326
  align-items: stretch;
327
+ height: 650px;
328
  }}
329
 
330
  .panel {{
 
398
  }}
399
  .upload-zone input[type="file"] {{ display: none; }}
400
  .upload-text {{ pointer-events: none; color: var(--ub-muted); }}
401
+
402
  .preview-grid {{
403
  display: none;
404
  grid-template-columns: repeat(auto-fill, minmax(70px, 1fr));
 
418
  display: flex; align-items: center; justify-content: center;
419
  cursor: pointer; font-size: 12px;
420
  }}
421
+
422
  .add-more-btn {{
423
  display: flex; align-items: center; justify-content: center;
424
  border: 2px dashed var(--ub-muted); border-radius: 4px;
 
485
  }}
486
  .action-icon:hover {{ color: var(--ub-orange); }}
487
 
488
+ /* OUTPUT PANEL */
489
+ .panel-body-output {{
490
  flex: 1; display: flex; flex-direction: column;
491
  padding: 0; position: relative;
492
  }}
493
+ .output-stage {{
494
  position: absolute; top: 0; left: 0; right: 0; bottom: 0;
495
  background: #111; overflow: hidden; display: flex;
496
  align-items: center; justify-content: center;
497
  }}
498
+ .output-empty {{ color: var(--ub-muted); text-align: center; z-index: 1; }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
+ .output-img {{
501
+ position: absolute; top: 0; left: 0; width: 100%; height: 100%;
502
+ object-fit: contain; display: none; user-select: none;
 
 
 
 
 
503
  }}
504
 
505
+ /* LOADER */
506
  .loader {{
507
+ position: absolute; inset: 0;
508
+ background: rgba(20, 0, 10, 0.7);
509
  backdrop-filter: blur(6px);
510
+ display: none; flex-direction: column;
511
  align-items: center; justify-content: center; z-index: 20;
512
  }}
513
  .spinner-single {{
 
529
  0%, 100% {{ opacity: 1; }}
530
  50% {{ opacity: 0.5; }}
531
  }}
532
+ @keyframes spin {{
533
+ to {{ transform: rotate(360deg); }}
534
+ }}
535
+
536
+ /* Seed Badge */
537
+ .seed-badge {{
538
+ display: none;
539
+ position: absolute; bottom: 12px; right: 12px;
540
+ background: rgba(0,0,0,0.6); color: #ccc; padding: 5px 10px;
541
+ border-radius: 20px; font-size: 12px; backdrop-filter: blur(4px);
542
+ z-index: 5;
543
  }}
544
 
545
  /* Examples */
 
559
 
560
  @media (max-width: 900px) {{
561
  .layout {{ grid-template-columns: 1fr; height: auto; }}
562
+ .panel-body-output {{ height: 450px; flex: none; }}
563
+ .output-stage {{ position: relative; height: 100%; }}
564
  }}
565
  </style>
566
  </head>
567
  <body>
568
 
569
+ <div class="topbar">Flux.2 Klein 4B Small Decoder</div>
570
 
571
  <div class="container">
572
  <div class="header-text">
573
+ <h1>Flux.2 Klein — Small Decoder VAE</h1>
574
+ <p>Upload an image (optional) and enter a prompt to generate with the 4B distilled model.</p>
575
  </div>
576
 
577
  <div class="layout">
578
+ <!-- Left: Settings Panel -->
579
  <div class="panel">
580
  <div class="panel-header">Settings</div>
581
  <div class="panel-body-scroll">
 
596
  <button class="advanced-toggle" id="advToggle">
597
  <span>Advanced Settings</span> <span class="advanced-icon" id="advIcon">+</span>
598
  </button>
599
+
600
  <div class="advanced-body" id="advBody">
601
  <div class="grid-2">
602
  <div class="form-group">
 
634
  </div>
635
  </div>
636
 
637
+ <button class="btn btn-primary" id="runBtn">Generate</button>
638
  </div>
639
  </div>
640
 
641
+ <!-- Right: Output Panel -->
642
  <div class="panel">
643
  <div class="panel-header">
644
+ <span>Output</span>
645
+ <button id="downloadBtn" class="action-icon" title="Download Image">
646
  <svg width="22" height="22" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" viewBox="0 0 24 24">
647
  <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"></path>
648
  <polyline points="7 10 12 15 17 10"></polyline>
 
650
  </svg>
651
  </button>
652
  </div>
653
+ <div class="panel-body-output">
654
+ <div class="output-stage" id="outputStage">
655
+ <div class="output-empty" id="outputEmpty">
656
  <svg width="48" height="48" fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" style="margin-bottom:10px; opacity:0.5;">
657
  <path d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z"></path>
658
  </svg>
659
+ <div>Result will appear here</div>
660
  </div>
661
 
662
+ <img id="outputImg" class="output-img" alt="Generated Output" />
 
 
 
 
 
 
663
 
664
+ <div class="seed-badge" id="seedBadge"></div>
665
 
666
  <div class="loader" id="loader">
667
  <div class="spinner-single"></div>
668
+ <div class="loader-text">Generating image...</div>
669
  </div>
670
  </div>
671
  </div>
672
  </div>
673
  </div>
674
 
675
+ <!-- Examples Section -->
676
  <div class="examples-section">
677
  <h3>Examples</h3>
678
  <div class="examples-grid" id="examplesGrid"></div>
 
682
  <script>
683
  const examples = {examples_json};
684
  let filesState = [];
685
+ let currentFilename = "";
686
+
687
+ const dropZone = document.getElementById('dropZone');
688
+ const fileInput = document.getElementById('fileInput');
689
+ const previewGrid = document.getElementById('previewGrid');
690
+ const uploadText = document.getElementById('uploadText');
691
+ const promptInput = document.getElementById('promptInput');
692
+ const runBtn = document.getElementById('runBtn');
693
+ const downloadBtn = document.getElementById('downloadBtn');
694
+ const statusLog = document.getElementById('statusLog');
695
+ const outputStage = document.getElementById('outputStage');
696
+ const outputImg = document.getElementById('outputImg');
697
+ const outputEmpty = document.getElementById('outputEmpty');
698
+ const loader = document.getElementById('loader');
699
+ const seedBadge = document.getElementById('seedBadge');
700
 
701
  function logMsg(msg, styleClass="") {{
702
  const div = document.createElement('div');
703
  const timeStr = new Date().toLocaleTimeString('en-US', {{hour12:false}});
704
  div.innerHTML = `<span class="log-time">[${{timeStr}}]</span><span class="${{styleClass}}">${{msg}}</span>`;
705
  statusLog.appendChild(div);
706
+ statusLog.scrollTop = statusLog.scrollHeight;
707
  }}
708
 
709
+ // --- Advanced Toggle ---
 
 
 
 
 
 
 
 
 
710
  document.getElementById('advToggle').onclick = function() {{
711
  const body = document.getElementById('advBody');
712
  body.classList.toggle('open');
 
716
  // --- File Upload Logic ---
717
  function renderPreviews() {{
718
  previewGrid.innerHTML = '';
719
+ if (filesState.length > 0) {{
720
  uploadText.style.display = 'none';
721
  previewGrid.style.display = 'grid';
722
+
723
  filesState.forEach((f, i) => {{
724
  const div = document.createElement('div');
725
  div.className = 'thumb';
 
732
  div.appendChild(img); div.appendChild(btn);
733
  previewGrid.appendChild(div);
734
  }});
735
+
 
736
  const addBtn = document.createElement('div');
737
  addBtn.className = 'add-more-btn';
738
  addBtn.innerHTML = '+';
739
  addBtn.onclick = (e) => {{ e.stopPropagation(); fileInput.click(); }};
740
  previewGrid.appendChild(addBtn);
 
741
  }} else {{
742
  uploadText.style.display = 'block';
743
  previewGrid.style.display = 'none';
744
  }}
745
  }}
746
 
747
+ dropZone.onclick = (e) => {{ if (e.target === dropZone || e.target === uploadText) fileInput.click(); }};
748
  fileInput.onchange = (e) => {{ filesState.push(...Array.from(e.target.files)); renderPreviews(); fileInput.value=''; }};
749
  dropZone.ondragover = (e) => {{ e.preventDefault(); dropZone.classList.add('dragover'); }};
750
  dropZone.ondragleave = () => dropZone.classList.remove('dragover');
751
  dropZone.ondrop = (e) => {{
752
  e.preventDefault(); dropZone.classList.remove('dragover');
753
+ if (e.dataTransfer.files.length) {{ filesState.push(...Array.from(e.dataTransfer.files)); renderPreviews(); }}
754
  }};
755
 
756
  // --- Examples Logic ---
 
759
  renderPreviews();
760
  promptInput.value = text;
761
  logMsg("Loading example: " + text, "log-info");
762
+
763
  try {{
764
+ for (let i = 0; i < urls.length; i++) {{
765
+ const res = await fetch(urls[i]);
766
+ const blob = await res.blob();
767
+ const filename = urls[i].split('/').pop();
768
+ filesState.push(new File([blob], filename, {{type: blob.type}}));
769
  }}
770
  renderPreviews();
 
771
  window.scrollTo({{top: 0, behavior: 'smooth'}});
 
772
  setTimeout(() => {{
773
+ logMsg("Example loaded. Starting generation...", "log-info");
774
+ runBtn.click();
775
  }}, 500);
776
+ }} catch (e) {{
 
777
  logMsg("Failed to load example images.", "log-error");
778
+ alert('Failed to load example image.');
779
  }}
780
  }}
781
 
 
783
  examples.forEach(ex => {{
784
  const card = document.createElement('div');
785
  card.className = 'ex-card';
786
+
787
  let imgHTML = '';
788
+ if (ex.urls.length > 1) {{
789
  imgHTML = `
790
  <div class="ex-card-img-wrap">
791
+ <img src="${{ex.urls[0]}}" style="width:50%; border-right:1px solid #000;">
792
+ <img src="${{ex.urls[1]}}" style="width:50%;">
793
+ </div>`;
 
794
  }} else {{
795
  imgHTML = `<div class="ex-card-img-wrap"><img src="${{ex.urls[0]}}" style="width:100%;"></div>`;
796
  }}
797
+
798
  card.innerHTML = `${{imgHTML}}<p>${{ex.prompt}}</p>`;
799
  card.onclick = () => loadExample(ex.urls, ex.prompt);
800
  exGrid.appendChild(card);
801
  }});
802
 
803
+ // --- Download Logic ---
804
+ downloadBtn.onclick = () => {{
805
+ if (!currentFilename) return;
806
+ logMsg("Downloading image...", "log-info");
807
+ window.location.href = `/download/${{currentFilename}}`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808
  }};
809
 
810
  // --- Form Submission ---
811
  runBtn.onclick = async () => {{
812
  const prompt = promptInput.value.trim();
813
+ if (!prompt) {{
814
  logMsg("Validation failed: Prompt is empty.", "log-error");
815
  return alert("Enter a prompt");
816
  }}
817
 
818
+ logMsg("Initializing generation...", "log-info");
819
 
820
  const fd = new FormData();
821
  fd.append('prompt', prompt);
 
825
  fd.append('height', document.getElementById('height').value);
826
  fd.append('steps', document.getElementById('steps').value);
827
  fd.append('guidance', document.getElementById('guidance').value);
828
+
829
  filesState.forEach(f => fd.append('images', f));
830
 
831
  loader.style.display = 'flex';
832
  runBtn.disabled = true;
833
+ downloadBtn.style.display = 'none';
834
+ seedBadge.style.display = 'none';
835
 
836
+ logMsg("Sending request... Running Small Decoder VAE model.", "log-info");
837
 
838
  try {{
839
+ const res = await fetch('/api/generate', {{ method: 'POST', body: fd }});
840
  const data = await res.json();
841
+
842
+ if (data.success) {{
843
+ logMsg(`Success! Used seed: ${{data.seed}}`, "log-success");
844
+
845
+ currentFilename = data.filename;
846
+ outputImg.src = data.url;
847
+
848
+ outputImg.onload = () => {{
849
+ outputEmpty.style.display = 'none';
850
+ outputImg.style.display = 'block';
851
+ downloadBtn.style.display = 'block';
852
+ seedBadge.innerText = `Seed: ${{data.seed}}`;
853
+ seedBadge.style.display = 'block';
 
 
 
 
 
 
 
 
854
  }};
855
  }} else {{
856
+ logMsg("Error: " + data.error, "log-error");
857
  alert('Error: ' + data.error);
858
  }}
859
+ }} catch (e) {{
860
  logMsg("Network or server connection failed.", "log-error");
861
  alert('Failed to connect to server.');
862
  }} finally {{
863
  loader.style.display = 'none';
864
  runBtn.disabled = false;
865
+ logMsg("Ready for next input.", "");
866
  }}
867
  }};
868
  </script>