namin72 committed on
Commit
629f4b0
·
verified ·
1 Parent(s): 34a9859

Upload 15 files

Browse files
inference_drag.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
3
+ from diffusers.utils import export_to_gif
4
+ from PIL import Image
5
+ import cv2
6
+ import numpy as np
7
+ import os
8
+
9
+ # ==========================================
10
+ # 1. Settings (must match the training setup to get good results)
11
+ # ==========================================
12
+ BASE_MODEL = "runwayml/stable-diffusion-v1-5"
13
+ MOTION_ADAPTER = "guoyww/animatediff-motion-adapter-v1-5-2"
14
+ LORA_PATH = "output_drag_lora" # output folder of the training run that just finished
15
+ TEST_IMAGE_PATH = "test_input.png" # the prepared input image
16
+
17
+ # Use exactly the same prompt that was used during training (important!)
18
+ PROMPT = "timelapse clouds moving in the sky, cinematic, high quality, 4k"
19
+ NEGATIVE_PROMPT = "bad quality, worst quality, blurry, low resolution, distortion, watermark"
20
+
21
+
22
+ # ==========================================
23
+ # 2. Color-corrected video save function (prevents a blue tint)
24
+ # ==========================================
25
def save_video_fixed(frames, path, fps=8):
    """Write RGB frames to an MP4 file with OpenCV, keeping colors correct.

    OpenCV writes frames in BGR channel order, so each frame is converted
    from RGB before being written; this conversion is what prevents the
    blue-tinted output.

    Args:
        frames: Non-empty sequence of RGB frames (PIL images, or anything
            ``np.array`` converts to a ``(height, width, 3)`` array). The
            video size is taken from the first frame.
        path: Destination path of the ``mp4v``-encoded video.
        fps: Frame rate passed to the video writer.

    Raises:
        ValueError: If ``frames`` is empty.
        OSError: If OpenCV could not open ``path`` for writing.
    """
    # Fail early with a clear message instead of an IndexError on frames[0].
    if len(frames) == 0:
        raise ValueError("frames must contain at least one image")

    height, width = np.array(frames[0]).shape[:2]

    # OpenCV video writer; note that it takes the size as (width, height).
    writer = cv2.VideoWriter(
        path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    if not writer.isOpened():
        # Without this check a failed open silently drops every frame and the
        # success message below would still be printed.
        writer.release()
        raise OSError(f"Could not open video writer for {path}")

    try:
        for frame in frames:
            # PIL -> NumPy, then RGB -> BGR (the key step for correct colors).
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the file, even if a frame conversion fails.
        writer.release()

    print(f"✨ 영상 저장 완료: {path}")
39
+
40
+
41
+ # ==========================================
42
+ # 3. Main entry function
43
+ # ==========================================
44
def main():
    """Generate a short clip with AnimateDiff plus the trained LoRA and save it.

    Steps: load the motion adapter and base pipeline, swap in a DDIM
    scheduler, load the LoRA attention weights from ``LORA_PATH``, run
    text-to-video inference with ``PROMPT`` / ``NEGATIVE_PROMPT`` and a fixed
    seed, then write ``final_result.gif`` and ``final_result.mp4``.

    Returns early (after printing the error) if the LoRA weights cannot be
    loaded.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only requested on GPU; on the CPU fallback float16
    # inference is generally unsupported or extremely slow, so use float32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    # Resolution used for the input image and the generated frames
    # (matches the training resolution).
    resolution = 256
    print(f"Using device: {device}")

    # 1. Load the models.
    print("Loading Base Model...")
    adapter = MotionAdapter.from_pretrained(MOTION_ADAPTER)
    pipe = AnimateDiffPipeline.from_pretrained(
        BASE_MODEL,
        motion_adapter=adapter,
        torch_dtype=dtype,
    ).to(device)

    # Scheduler setup.
    pipe.scheduler = DDIMScheduler.from_pretrained(
        BASE_MODEL,
        subfolder="scheduler",
        clip_sample=False,
        timestep_spacing="linspace",
        beta_schedule="linear",
        steps_offset=1,
    )

    # 2. Load the trained LoRA weights (this is where the training result is checked).
    print(f"Loading LoRA from {LORA_PATH}...")
    try:
        pipe.unet.load_attn_procs(LORA_PATH)
    except Exception as e:
        # Top-level boundary of the script: report the failure and stop.
        print(f"❌ LoRA Load Failed: {e}")
        return
    print("✅ LoRA Load Success!")

    # 3. Prepare the input image.
    # NOTE(review): input_image is prepared here but never passed to the
    # pipeline call below, so it does not influence the generated video.
    # Confirm whether an image-conditioned pipeline was intended.
    if not os.path.exists(TEST_IMAGE_PATH):
        print(f"⚠️ {TEST_IMAGE_PATH}가 없습니다! 검은 화면으로 테스트합니다.")
        input_image = Image.new('RGB', (resolution, resolution), color='black')
    else:
        input_image = Image.open(TEST_IMAGE_PATH).convert("RGB")
        # Match the training resolution.
        input_image = input_image.resize((resolution, resolution))

    # 4. Generate the video (inference).
    print("Generating Video... (약 1분 소요)")

    # Fixed seed so every run produces the same result.
    generator = torch.Generator(device=device).manual_seed(42)

    output = pipe(
        prompt=PROMPT,
        negative_prompt=NEGATIVE_PROMPT,
        num_frames=16,  # 2-second clip at the 8 fps used when saving
        guidance_scale=7.5,
        num_inference_steps=25,  # 25 denoising steps are enough
        generator=generator,
        width=resolution,
        height=resolution,
    )

    # Frames of the first (and only) generated video in the batch.
    frames = output.frames[0]

    # 5. Save (GIF + MP4).
    export_to_gif(frames, "final_result.gif")
    save_video_fixed(frames, "final_result.mp4", fps=8)

    print("🎉 모든 작업 완료! 'final_result.mp4'를 확인하세요.")


if __name__ == "__main__":
    main()
output_drag_lora/checkpoint-200/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3dc451b5b80d8f93d87dc2a4804e24ed2ca8d950dcae1bd8973fdf7520fab237
3
+ size 5269637960
output_drag_lora/checkpoint-200/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd45978dd2af53b0e7fa1c34949f65ec801a87fb3d75b19420692cc61ef3b8ef
3
+ size 32586114
output_drag_lora/checkpoint-200/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9fe4cbbb6dd28903cee54423e7adeb58305fed600351ad550648a40b914db57
3
+ size 14344
output_drag_lora/checkpoint-200/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c5e66cd0701a17acd6a6484ee4019ae1d960fac369fffe82987ea5dd7915828
3
+ size 988
output_drag_lora/checkpoint-400/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55badd86b56c1084d6ff85506e77c95303c35a50ea4c27a476ee9ad1faafc05d
3
+ size 5269637960
output_drag_lora/checkpoint-400/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0a7048d044782e7341cbfb89b12b6d552d4d9914bfc2d2750c0f7ae9a73748c
3
+ size 32586114
output_drag_lora/checkpoint-400/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c24300fc53e165ac7325140a2cfe55dcd85aa244509ddf785381b83bc638c02b
3
+ size 14344
output_drag_lora/checkpoint-400/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ed35716f7a9275082254cfbe153b60686b76c261de984a9e4d37dfb27c8ef24
3
+ size 988
output_drag_lora/checkpoint-600/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52a22baa43787af67caba82823f0553d510813387a7e5d801e872fff37a2b781
3
+ size 5269637960
output_drag_lora/checkpoint-600/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92514ffaa22d74392ed8688a8a0535d92c1be145410e80ef12f44314a5465f24
3
+ size 32586114
output_drag_lora/checkpoint-600/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b68c4c937eb2a1de34c6c4c7a4892625c0541036db3d34daa8391c3bcc16dc9f
3
+ size 14344
output_drag_lora/checkpoint-600/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30cc72c17790183d209119d8c6a06b1b84e2e99ea44df74f41e2a84800dba3b3
3
+ size 988
output_drag_lora/pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e6d416718ea78e54b1921d3ca3df89c559743b712e8f1ac61253321d84229f5
3
+ size 16128712
requirements.txt ADDED
Binary file (4.74 kB). View file