cocoat commited on
Commit
8d52761
·
verified ·
1 Parent(s): 9a10e39

Update app.py

Browse files

import gradio as gr
import torch
import random
import numpy as np
import datetime

# 履歴保存
from huggingface_hub import HfApi
from huggingface_hub import login

import os
# HF_TOKEN 環境変数からトークンを明示的に読み込む
hf_token_value = os.getenv("HF_TOKEN")

if hf_token_value:
api = HfApi(token=hf_token_value)
print("token ok.")
else:
# トークンが設定されていない場合の警告と代替処理
print("HF_TOKEN error")
api = HfApi() # トークンなしで初期化

# 画像をアップロードするリポジトリID
HF_REPO_ID = "cocoat/images"

# 公開用画像をアップロードするリポジトリID
PUBLIC_REPO_ID = "cocoat/opendata"

# Space内で画像を保存するディレクトリ
SPACE_IMAGE_DIR = "generated_images"
os.makedirs(SPACE_IMAGE_DIR, exist_ok=True)

# 公開リポジトリの画像ディレクトリ
PUBLIC_IMAGE_DIR = "generated_images"
os.makedirs(PUBLIC_IMAGE_DIR, exist_ok=True)

# 履歴ファイルを定義
HISTORY_FILE = "history/generation_history_coamixXL3.txt"

# 履歴をロードする関数
import os
import requests
def load_history():
history_data = []
hf_raw_file_url = f"https://huggingface.co/datasets/{HF_REPO_ID}/raw/main/{HISTORY_FILE}"
headers = {}
if hf_token_value:
headers["Authorization"] = f"Bearer {hf_token_value}"

try:
response = requests.get(hf_raw_file_url, headers=headers)
response.raise_for_status()

loaded_hub_paths = set() # 重複ロードを防ぐため

for line in response.text.splitlines():
parts = line.strip().split("|||")
if len(parts) == 2:
image_path_in_repo = parts[0]
caption = parts[1]
# 公開リポジトリの画像URLを生成
hub_image_url = f"https://huggingface.co/datasets/{PUBLIC_REPO_ID}/resolve/main/{image_path_in_repo}"

history_data.append((image_path_in_repo, caption, hub_image_url))
loaded_hub_paths.add(image_path_in_repo)
print(f"History loaded from Hub and matched with Space images: {len(history_data)} entries.")
except requests.exceptions.RequestException as e:
print(f"Error loading history from Hub via raw URL: {e}. Starting with empty history.")
except Exception as e:
print(f"An unexpected error occurred while parsing history: {e}. Starting with empty history.")

return history_data[:10]


# 履歴を初期化時にロード (修正された load_history を使用)
history = load_history()


# 履歴を初期化時にロード
history = load_history()

from PIL import Image
from diffusers import (
StableDiffusionXLPipeline,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler
)
from huggingface_hub import hf_hub_download, HfApi

# デバイスと型の設定
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
MAX_SEED = np.iinfo(np.int32).max
MAX_SIZE = 2048

# モデルファイルのダウンロード
model_path = hf_hub_download(
repo_id="cocoat/cocoamix",
filename="recocoamixXL3_coamixXL3.safetensors"
)

# パイプライン構築
pipe = StableDiffusionXLPipeline.from_single_file(
model_path,
torch_dtype=torch_dtype,
use_safetensors=True
).to(device)

# スケジューラ設定
euler_scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config,
use_karras_sigmas=True
)
dpm_scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = euler_scheduler

def upload_image_to_hub(image_pil, prompt_text):
# ファイル名を生成(タイムスタンプとプロンプトの一部)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
# プロンプトから安全なファイル名の一部を生成
# safe_prompt = "".join(c for c in prompt_text if c.isalnum() or c in (' ', '.', '_')).replace(' ', '_')[:30]
filename = f"image_{timestamp}.png"
filepath = f"temp_{filename}"
image_pil.save(filepath)

# Hubにアップロード
try:
# リポジトリ内にディレクトリを作成する場合は path_in_repo を使う
path_in_repo = f"generated_images/{filename}"
api.upload_file(
path_or_fileobj=filepath,
path_in_repo=path_in_repo,
repo_id=PUBLIC_REPO_ID, # 公開用リポジトリに変更
repo_type="dataset",
)

# アップロードされたファイルのURLを構築する(PUBLIC_REPO_IDを使用)
uploaded_file_url = f"https://huggingface.co/datasets/{PUBLIC_REPO_ID}/resolve/main/{path_in_repo}"
print(f"Uploaded {filepath} to {uploaded_file_url}")

# 公開リポジトリの古いファイルを削除するロジック
current_files = api.list_repo_files(repo_id=PUBLIC_REPO_ID, repo_type="dataset")
# PUBLIC_IMAGE_DIR (generated_images) 以下のpngファイルを抽出し、新しいものからソート
generated_images_in_public = sorted([f for f in current_files if f.startswith(PUBLIC_IMAGE_DIR) and f.endswith('.png')], reverse=True)

# 10枚を超える場合、古いファイルを削除
if len(generated_images_in_public) > 10:
files_to_delete = generated_images_in_public[10:]
for file_to_delete in files_to_delete:
try:
api.delete_file(
path_in_repo=file_to_delete,
repo_id=PUBLIC_REPO_ID,
repo_type="dataset",
commit_message=f"Delete old image: {file_to_delete}"
)
print(f"Deleted old public image: {file_to_delete}")
except Exception as del_e:
print(f"Error deleting old public image {file_to_delete}: {del_e}")
return uploaded_file_url, path_in_repo
except Exception as e:
print(f"Error uploading image to Hub: {e}")
return None, None
finally:
# 一時ファイルを削除
if os.path.exists(filepath):
os.remove(filepath)
print(f"一時画像ファイル {filepath} を削除しました。")

def upload_image_to_private_hub(image_pil, prompt_text):
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
filename = f"image_{timestamp}.png"
filepath = f"temp_private_{filename}" # 一時ファイル名が重複しないように private を追加
image_pil.save(filepath)

try:
path_in_repo = f"generated_images/{filename}"
api.upload_file(
path_or_fileobj=filepath,
path_in_repo=path_in_repo,
repo_id=HF_REPO_ID, # 非公開リポジトリ
repo_type="dataset",
)
print(f"Uploaded {filepath} to private Hub: {path_in_repo}")
return path_in_repo # 履歴ファイルに記録するリポジトリ内パスを返す
except Exception as e:
print(f"Error uploading image to private Hub: {e}")
return None
finally:
if os.path.exists(filepath):
os.remove(filepath)
print(f"一時プライベート画像ファイル {filepath} を削除しました。")

def make_html_table(caption):
formatted_caption = caption.replace("|-|", "\n")
rows = formatted_caption.split("\n")
html = '<table style="width:100%;border-collapse:collapse;background:#fffaf1;color:#000">'
for row in rows:
if ": " in row:
key, val = row.split(": ", 1)
html += (
# f'{key}: {val}\n'
f'<tr><th style="text-align:left;border:1px solid #ddd;padding:4px;">{key}</th>'
f'<td style="border:1px solid #ddd;padding:4px;">{val}</td></tr>'
)
html += '</table>'
return html

def create_dummy_image(width=512, height=512, alpha=0):
return Image.new("RGBA", (width, height), (0, 0, 0, alpha))

def update_history_tables_on_select(evt: gr.SelectData):
if evt.index is not None and 0 <= evt.index < len(history):
selected_caption = history[evt.index][1]
return make_html_table(selected_caption)
return ""

def update_history():
tables_html = "".join(
f'<div style="margin-bottom:12px">{make_html_table(item[1])}</div>'
for item in history
)
return tables_html

def infer(prompt, neg, seed, rand, w, h, cfg, steps, scheduler_type,
progress=gr.Progress(track_tqdm=True)):
if rand:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed)

pipe.scheduler = euler_scheduler if scheduler_type == "Euler Ancestral" else dpm_scheduler
pipe.scheduler.set_timesteps(steps)

def _callback(pipeline, step_idx, timestep, callback_kwargs):
progress(step_idx / steps, desc=f"Step {step_idx}/{steps}")
return callback_kwargs

output = pipe(
prompt=prompt,
negative_prompt=neg or None,
guidance_scale=cfg,
num_inference_steps=steps,
width=w,
height=h,
generator=generator,
callback_on_step_end=_callback
)
img = output.images[0]

caption_text = (
f"Prompt: {prompt}\n"
f"Negative: {neg or 'None'}\n"
f"Seed: {seed}\n"
f"Size: {w}×{h}\n"
f"CFG: {cfg}\n"
f"Steps: {steps}\n"
f"Scheduler: {scheduler_type}"
)

caption_text_for_history = caption_text.replace("\n", "|-|").strip()
# 画像をHubにアップロードし、そのURLとリポジトリ内パスを取得
# 公開用リポジトリにアップロード
uploaded_image_url, path_in_public_repo_for_history = upload_image_to_hub(img, prompt)
# 非公開リポジトリにアップロード
path_in_private_repo_for_history = upload_image_to_private_hub(img, prompt)

# 履歴を更新
global history
# Hubへのアップロードが成功した場合のみ履歴に追加
# historyリストには (非公開リポジトリのパス, キャプション, 公開リポジトリのURL)

Files changed (1) hide show
  1. app.py +93 -87
app.py CHANGED
@@ -71,11 +71,6 @@ def load_history():
71
 
72
  return history_data[:10]
73
 
74
-
75
- # 履歴を初期化時にロード (修正された load_history を使用)
76
- history = load_history()
77
-
78
-
79
  # 履歴を初期化時にロード
80
  history = load_history()
81
 
@@ -162,10 +157,6 @@ def upload_image_to_hub(image_pil, prompt_text):
162
  print(f"Error uploading image to Hub: {e}")
163
  return None, None
164
  finally:
165
- # 一時ファイルを削除
166
- if os.path.exists(filepath):
167
- os.remove(filepath)
168
- print(f"一時画像ファイル {filepath} を削除しました。")
169
 
170
  def upload_image_to_private_hub(image_pil, prompt_text):
171
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
@@ -187,9 +178,6 @@ def upload_image_to_private_hub(image_pil, prompt_text):
187
  print(f"Error uploading image to private Hub: {e}")
188
  return None
189
  finally:
190
- if os.path.exists(filepath):
191
- os.remove(filepath)
192
- print(f"一時プライベート画像ファイル {filepath} を削除しました。")
193
 
194
  def make_html_table(caption):
195
  formatted_caption = caption.replace("|-|", "\n")
@@ -224,88 +212,103 @@ def update_history():
224
 
225
  def infer(prompt, neg, seed, rand, w, h, cfg, steps, scheduler_type,
226
  progress=gr.Progress(track_tqdm=True)):
227
- if rand:
228
- seed = random.randint(0, MAX_SEED)
229
- generator = torch.Generator(device=device).manual_seed(seed)
230
-
231
- pipe.scheduler = euler_scheduler if scheduler_type == "Euler Ancestral" else dpm_scheduler
232
- pipe.scheduler.set_timesteps(steps)
233
-
234
- def _callback(pipeline, step_idx, timestep, callback_kwargs):
235
- progress(step_idx / steps, desc=f"Step {step_idx}/{steps}")
236
- return callback_kwargs
237
-
238
- output = pipe(
239
- prompt=prompt,
240
- negative_prompt=neg or None,
241
- guidance_scale=cfg,
242
- num_inference_steps=steps,
243
- width=w,
244
- height=h,
245
- generator=generator,
246
- callback_on_step_end=_callback
247
- )
248
- img = output.images[0]
249
-
250
- caption_text = (
251
- f"Prompt: {prompt}\n"
252
- f"Negative: {neg or 'None'}\n"
253
- f"Seed: {seed}\n"
254
- f"Size: {w}×{h}\n"
255
- f"CFG: {cfg}\n"
256
- f"Steps: {steps}\n"
257
- f"Scheduler: {scheduler_type}"
258
- )
259
-
260
- caption_text_for_history = caption_text.replace("\n", "|-|").strip()
261
- # 画像をHubにアップロードし、そのURLとリポジトリ内パスを取得
262
- # 公開用リポジトリにアップロード
263
- uploaded_image_url, path_in_public_repo_for_history = upload_image_to_hub(img, prompt)
264
- # 非公開リポジトリにアップロード
265
- path_in_private_repo_for_history = upload_image_to_private_hub(img, prompt)
266
-
267
- # 履歴を更新
268
- global history
269
- # Hubへのアップロードが成功した場合のみ履歴に追加
270
- # historyリストには (非公開リポジトリのパス, キャプション, 公開リポジトリのURL) の形式で保存
271
- if path_in_private_repo_for_history and uploaded_image_url:
272
- history.insert(0, (path_in_private_repo_for_history, caption_text_for_history, uploaded_image_url))
273
- else:
274
- print(f"Skipping history update due to failed Hub upload.")
275
-
276
- history_max_items = 10
277
- if len(history) > history_max_items:
278
- history.pop()
279
-
280
- # 履歴ファイルを更新し、Hubにアップロードする
281
- temp_history_filepath = "temp_history.txt"
282
- with open(temp_history_filepath, "w", encoding="utf-8") as f:
283
- for img_path_in_repo, cap_text, _ in history:
284
- f.write(f"{img_path_in_repo}|||{cap_text}\n")
285
 
286
  try:
287
- api.upload_file(
288
- path_or_fileobj=temp_history_filepath,
289
- path_in_repo=HISTORY_FILE,
290
- repo_id=HF_REPO_ID,
291
- repo_type="dataset",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  )
293
- print(f"History file '{HISTORY_FILE}' updated on Hugging Face Hub.")
294
- except Exception as e:
295
- print(f"Error updating history file on Hub: {e}")
296
- finally:
297
- if os.path.exists(temp_history_filepath):
298
- os.remove(temp_history_filepath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
- progress(1.0, desc="Done!")
301
 
302
- # ギャラリー表示用のアイテムリストを生成(Hub上のURLを使用)
303
- gallery_items = [(item[2], item[1].replace("|-|", "\n")) for item in history]
304
 
305
- processed_img, processed_gallery_items = process_image(img, gallery_items)
306
 
307
- latest_caption_table = make_html_table(caption_text)
308
- return processed_img, processed_gallery_items, latest_caption_table
 
 
 
 
 
 
309
 
310
  import gc
311
  import torch
@@ -373,6 +376,9 @@ html, .gradio-container, .dark, .dark * {
373
  word-wrap:break-word !important;
374
  display: none !important;
375
  }
 
 
 
376
  .gradio-spinner { display: none !important; }
377
  #custom-loader {
378
  align-items: center;
 
71
 
72
  return history_data[:10]
73
 
 
 
 
 
 
74
  # 履歴を初期化時にロード
75
  history = load_history()
76
 
 
157
  print(f"Error uploading image to Hub: {e}")
158
  return None, None
159
  finally:
 
 
 
 
160
 
161
  def upload_image_to_private_hub(image_pil, prompt_text):
162
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
 
178
  print(f"Error uploading image to private Hub: {e}")
179
  return None
180
  finally:
 
 
 
181
 
182
  def make_html_table(caption):
183
  formatted_caption = caption.replace("|-|", "\n")
 
212
 
213
  def infer(prompt, neg, seed, rand, w, h, cfg, steps, scheduler_type,
214
  progress=gr.Progress(track_tqdm=True)):
215
+ timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
216
+ filename = f"image_{timestamp}.png"
217
+ filepath = f"temp_{filename}" # ここで一時ファイルのパスを定義
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  try:
220
+
221
+ if rand:
222
+ seed = random.randint(0, MAX_SEED)
223
+ generator = torch.Generator(device=device).manual_seed(seed)
224
+
225
+ pipe.scheduler = euler_scheduler if scheduler_type == "Euler Ancestral" else dpm_scheduler
226
+ pipe.scheduler.set_timesteps(steps)
227
+
228
+ def _callback(pipeline, step_idx, timestep, callback_kwargs):
229
+ progress(step_idx / steps, desc=f"Step {step_idx}/{steps}")
230
+ return callback_kwargs
231
+
232
+ output = pipe(
233
+ prompt=prompt,
234
+ negative_prompt=neg or None,
235
+ guidance_scale=cfg,
236
+ num_inference_steps=steps,
237
+ width=w,
238
+ height=h,
239
+ generator=generator,
240
+ callback_on_step_end=_callback
241
  )
242
+ img = output.images[0]
243
+
244
+ img.save(filepath)
245
+
246
+ caption_text = (
247
+ f"Prompt: {prompt}\n"
248
+ f"Negative: {neg or 'None'}\n"
249
+ f"Seed: {seed}\n"
250
+ f"Size: {w}×{h}\n"
251
+ f"CFG: {cfg}\n"
252
+ f"Steps: {steps}\n"
253
+ f"Scheduler: {scheduler_type}"
254
+ )
255
+
256
+ caption_text_for_history = caption_text.replace("\n", "|-|").strip()
257
+ # 画像をHubにアップロードし、そのURLとリポジトリ内パスを取得
258
+ # 公開用リポジトリにアップロード
259
+ uploaded_image_url, path_in_public_repo_for_history = upload_image_to_hub(filepath, filename)
260
+ # 非公開リポジトリにアップロード
261
+ path_in_private_repo_for_history = upload_image_to_private_hub(filepath, filename)
262
+
263
+
264
+ # 履歴を更新
265
+ global history
266
+ # Hubへのアップロードが成功した場合のみ履歴に追加
267
+ # historyリストには (非公開リポジトリのパス, キャプション, 公開リポジトリのURL) の形式で保存
268
+ if path_in_private_repo_for_history and uploaded_image_url:
269
+ history.insert(0, (path_in_private_repo_for_history, caption_text_for_history, uploaded_image_url))
270
+ else:
271
+ print(f"Skipping history update due to failed Hub upload.")
272
+
273
+ history_max_items = 10
274
+ if len(history) > history_max_items:
275
+ history.pop()
276
+
277
+ # 履歴ファイルを更新し、Hubにアップロードする
278
+ temp_history_filepath = "temp_history.txt"
279
+ with open(temp_history_filepath, "w", encoding="utf-8") as f:
280
+ for img_path_in_repo, cap_text, _ in history:
281
+ f.write(f"{img_path_in_repo}|||{cap_text}\n")
282
+
283
+ try:
284
+ api.upload_file(
285
+ path_or_fileobj=temp_history_filepath,
286
+ path_in_repo=HISTORY_FILE,
287
+ repo_id=HF_REPO_ID,
288
+ repo_type="dataset",
289
+ )
290
+ print(f"History file '{HISTORY_FILE}' updated on Hugging Face Hub.")
291
+ except Exception as e:
292
+ print(f"Error updating history file on Hub: {e}")
293
+ finally:
294
+ if os.path.exists(temp_history_filepath):
295
+ os.remove(temp_history_filepath)
296
 
297
+ progress(1.0, desc="Done!")
298
 
299
+ # ギャラリー表示用のアイテムリストを生成(Hub上のURLを使用)
300
+ gallery_items = [(item[2], item[1].replace("|-|", "\n")) for item in history]
301
 
302
+ processed_img, processed_gallery_items = process_image(img, gallery_items)
303
 
304
+ latest_caption_table = make_html_table(caption_text)
305
+ return processed_img, processed_gallery_items, latest_caption_table
306
+
307
+ finally:
308
+ # infer関数の最後に一時ファイルを削除
309
+ if os.path.exists(filepath):
310
+ os.remove(filepath)
311
+ print(f"生成画像の一時ファイル {filepath} を削除しました。")
312
 
313
  import gc
314
  import torch
 
376
  word-wrap:break-word !important;
377
  display: none !important;
378
  }
379
+ .caption.svelte-842rpi.svelte-842rpi{
380
+ display: none !important;
381
+ }
382
  .gradio-spinner { display: none !important; }
383
  #custom-loader {
384
  align-items: center;