Upload 15 files
Browse files- inference_drag.py +111 -0
- output_drag_lora/checkpoint-200/model.safetensors +3 -0
- output_drag_lora/checkpoint-200/optimizer.bin +3 -0
- output_drag_lora/checkpoint-200/random_states_0.pkl +3 -0
- output_drag_lora/checkpoint-200/scaler.pt +3 -0
- output_drag_lora/checkpoint-400/model.safetensors +3 -0
- output_drag_lora/checkpoint-400/optimizer.bin +3 -0
- output_drag_lora/checkpoint-400/random_states_0.pkl +3 -0
- output_drag_lora/checkpoint-400/scaler.pt +3 -0
- output_drag_lora/checkpoint-600/model.safetensors +3 -0
- output_drag_lora/checkpoint-600/optimizer.bin +3 -0
- output_drag_lora/checkpoint-600/random_states_0.pkl +3 -0
- output_drag_lora/checkpoint-600/scaler.pt +3 -0
- output_drag_lora/pytorch_lora_weights.safetensors +3 -0
- requirements.txt +0 -0
inference_drag.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
| 3 |
+
from diffusers.utils import export_to_gif
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# ==========================================
|
| 10 |
+
# 1. ์ค์ (ํ์ต ๋์ ๋ง์ถฐ์ผ ์ ๋์ต๋๋ค)
|
| 11 |
+
# ==========================================
|
| 12 |
+
BASE_MODEL = "runwayml/stable-diffusion-v1-5"
|
| 13 |
+
MOTION_ADAPTER = "guoyww/animatediff-motion-adapter-v1-5-2"
|
| 14 |
+
LORA_PATH = "output_drag_lora" # ๋ฐฉ๊ธ ํ์ต ๋๋ ํด๋
|
| 15 |
+
TEST_IMAGE_PATH = "test_input.png" # ์ค๋นํ ์ด๋ฏธ์ง
|
| 16 |
+
|
| 17 |
+
# ํ์ต ๋ ์ผ๋ ํ๋กฌํํธ ๊ทธ๋๋ก ์ฌ์ฉ (์ค์!)
|
| 18 |
+
PROMPT = "timelapse clouds moving in the sky, cinematic, high quality, 4k"
|
| 19 |
+
NEGATIVE_PROMPT = "bad quality, worst quality, blurry, low resolution, distortion, watermark"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ==========================================
|
| 23 |
+
# 2. ์์ ๋ณด์ ์ ์ฅ ํจ์ (ํ๋์ ๋ฐฉ์ง)
|
| 24 |
+
# ==========================================
|
| 25 |
+
def save_video_fixed(frames, path, fps=8):
|
| 26 |
+
height, width, _ = np.array(frames[0]).shape
|
| 27 |
+
# OpenCV ๋น๋์ค ์์ฑ๊ธฐ
|
| 28 |
+
out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
| 29 |
+
|
| 30 |
+
for frame in frames:
|
| 31 |
+
# PIL -> Numpy ๋ณํ
|
| 32 |
+
img_np = np.array(frame)
|
| 33 |
+
# RGB -> BGR ๋ณํ (์ด๊ฒ ํต์ฌ!)
|
| 34 |
+
img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 35 |
+
out.write(img_bgr)
|
| 36 |
+
|
| 37 |
+
out.release()
|
| 38 |
+
print(f"โจ ์์ ์ ์ฅ ์๋ฃ: {path}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ==========================================
|
| 42 |
+
# 3. ๋ฉ์ธ ์คํ ํจ์
|
| 43 |
+
# ==========================================
|
| 44 |
+
def main():
|
| 45 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
+
print(f"Using device: {device}")
|
| 47 |
+
|
| 48 |
+
# 1. ๋ชจ๋ธ ๋ก๋
|
| 49 |
+
print("Loading Base Model...")
|
| 50 |
+
adapter = MotionAdapter.from_pretrained(MOTION_ADAPTER)
|
| 51 |
+
pipe = AnimateDiffPipeline.from_pretrained(
|
| 52 |
+
BASE_MODEL,
|
| 53 |
+
motion_adapter=adapter,
|
| 54 |
+
torch_dtype=torch.float16
|
| 55 |
+
).to(device)
|
| 56 |
+
|
| 57 |
+
# ์ค์ผ์ค๋ฌ ์ค์
|
| 58 |
+
pipe.scheduler = DDIMScheduler.from_pretrained(
|
| 59 |
+
BASE_MODEL,
|
| 60 |
+
subfolder="scheduler",
|
| 61 |
+
clip_sample=False,
|
| 62 |
+
timestep_spacing="linspace",
|
| 63 |
+
beta_schedule="linear",
|
| 64 |
+
steps_offset=1
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# 2. ํ์ตํ LoRA ๋ถ๋ฌ์ค๊ธฐ (์ฑ์ ํ ํ์ธ)
|
| 68 |
+
print(f"Loading LoRA from {LORA_PATH}...")
|
| 69 |
+
try:
|
| 70 |
+
pipe.unet.load_attn_procs(LORA_PATH)
|
| 71 |
+
print("โ
LoRA Load Success!")
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(f"โ LoRA Load Failed: {e}")
|
| 74 |
+
return
|
| 75 |
+
|
| 76 |
+
# 3. ์ด๋ฏธ์ง ์ค๋น
|
| 77 |
+
if not os.path.exists(TEST_IMAGE_PATH):
|
| 78 |
+
print(f"โ ๏ธ {TEST_IMAGE_PATH}๊ฐ ์์ต๋๋ค! ๊ฒ์ ํ๋ฉด์ผ๋ก ํ
์คํธํฉ๋๋ค.")
|
| 79 |
+
input_image = Image.new('RGB', (256, 256), color='black')
|
| 80 |
+
else:
|
| 81 |
+
input_image = Image.open(TEST_IMAGE_PATH).convert("RGB")
|
| 82 |
+
input_image = input_image.resize((256, 256)) # ํ์ต ํด์๋ ๋ง์ถค
|
| 83 |
+
|
| 84 |
+
# 4. ์์ ์์ฑ (Inference)
|
| 85 |
+
print("Generating Video... (์ฝ 1๋ถ ์์)")
|
| 86 |
+
|
| 87 |
+
# ์๋ ๊ณ ์ (๋งค๋ฒ ๋๊ฐ์ด ์ ๋์ค๊ฒ ํ๊ธฐ ์ํด)
|
| 88 |
+
generator = torch.Generator(device=device).manual_seed(42)
|
| 89 |
+
|
| 90 |
+
output = pipe(
|
| 91 |
+
prompt=PROMPT,
|
| 92 |
+
negative_prompt=NEGATIVE_PROMPT,
|
| 93 |
+
num_frames=16, # 2์ด ์์
|
| 94 |
+
guidance_scale=7.5,
|
| 95 |
+
num_inference_steps=25, # 25๋ฒ๋ง ๊ทธ๋ ค๋ ์ถฉ๋ถ
|
| 96 |
+
generator=generator,
|
| 97 |
+
width=256,
|
| 98 |
+
height=256
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
frames = output.frames[0]
|
| 102 |
+
|
| 103 |
+
# 5. ์ ์ฅ (GIF + MP4)
|
| 104 |
+
export_to_gif(frames, "final_result.gif")
|
| 105 |
+
save_video_fixed(frames, "final_result.mp4", fps=8)
|
| 106 |
+
|
| 107 |
+
print("๐ ๋ชจ๋ ์์
์๋ฃ! 'final_result.mp4'๋ฅผ ํ์ธํ์ธ์.")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
|
| 111 |
+
main()
|
output_drag_lora/checkpoint-200/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3dc451b5b80d8f93d87dc2a4804e24ed2ca8d950dcae1bd8973fdf7520fab237
|
| 3 |
+
size 5269637960
|
output_drag_lora/checkpoint-200/optimizer.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd45978dd2af53b0e7fa1c34949f65ec801a87fb3d75b19420692cc61ef3b8ef
|
| 3 |
+
size 32586114
|
output_drag_lora/checkpoint-200/random_states_0.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9fe4cbbb6dd28903cee54423e7adeb58305fed600351ad550648a40b914db57
|
| 3 |
+
size 14344
|
output_drag_lora/checkpoint-200/scaler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c5e66cd0701a17acd6a6484ee4019ae1d960fac369fffe82987ea5dd7915828
|
| 3 |
+
size 988
|
output_drag_lora/checkpoint-400/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:55badd86b56c1084d6ff85506e77c95303c35a50ea4c27a476ee9ad1faafc05d
|
| 3 |
+
size 5269637960
|
output_drag_lora/checkpoint-400/optimizer.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d0a7048d044782e7341cbfb89b12b6d552d4d9914bfc2d2750c0f7ae9a73748c
|
| 3 |
+
size 32586114
|
output_drag_lora/checkpoint-400/random_states_0.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c24300fc53e165ac7325140a2cfe55dcd85aa244509ddf785381b83bc638c02b
|
| 3 |
+
size 14344
|
output_drag_lora/checkpoint-400/scaler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ed35716f7a9275082254cfbe153b60686b76c261de984a9e4d37dfb27c8ef24
|
| 3 |
+
size 988
|
output_drag_lora/checkpoint-600/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:52a22baa43787af67caba82823f0553d510813387a7e5d801e872fff37a2b781
|
| 3 |
+
size 5269637960
|
output_drag_lora/checkpoint-600/optimizer.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92514ffaa22d74392ed8688a8a0535d92c1be145410e80ef12f44314a5465f24
|
| 3 |
+
size 32586114
|
output_drag_lora/checkpoint-600/random_states_0.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b68c4c937eb2a1de34c6c4c7a4892625c0541036db3d34daa8391c3bcc16dc9f
|
| 3 |
+
size 14344
|
output_drag_lora/checkpoint-600/scaler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:30cc72c17790183d209119d8c6a06b1b84e2e99ea44df74f41e2a84800dba3b3
|
| 3 |
+
size 988
|
output_drag_lora/pytorch_lora_weights.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e6d416718ea78e54b1921d3ca3df89c559743b712e8f1ac61253321d84229f5
|
| 3 |
+
size 16128712
|
requirements.txt
ADDED
|
Binary file (4.74 kB). View file
|
|
|