ginipick commited on
Commit
569bdec
·
verified ·
1 Parent(s): c057ff1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -0
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_client import Client, handle_file
3
+ import logging
4
+ import traceback
5
+ from datetime import datetime
6
+ import json
7
+
8
+ # 로깅 설정
9
+ logging.basicConfig(
10
+ level=logging.DEBUG,
11
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12
+ )
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # API 엔드포인트
16
+ API_URL = "http://211.233.58.201:7788/"
17
+
18
+ def test_api_connection():
19
+ """API 연결 테스트"""
20
+ try:
21
+ client = Client(API_URL)
22
+ logger.info(f"Successfully connected to API: {API_URL}")
23
+ return True, "API 연결 성공!"
24
+ except Exception as e:
25
+ logger.error(f"Failed to connect to API: {str(e)}")
26
+ return False, f"API 연결 실패: {str(e)}"
27
+
28
+ def generate_animation(image, audio, guidance_scale, steps, progress=gr.Progress()):
29
+ """애니메이션 생성 함수"""
30
+ logger.info("=== 애니메이션 생성 시작 ===")
31
+ logs = []
32
+
33
+ try:
34
+ # 입력 파라미터 로깅
35
+ log_msg = f"입력 파라미터:\n- Image: {image}\n- Audio: {audio}\n- Guidance Scale: {guidance_scale}\n- Steps: {steps}"
36
+ logger.info(log_msg)
37
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] {log_msg}")
38
+
39
+ # 입력 검증
40
+ if image is None:
41
+ error_msg = "이미지가 제공되지 않았습니다."
42
+ logger.error(error_msg)
43
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {error_msg}")
44
+ return None, None, "\n".join(logs)
45
+
46
+ if audio is None:
47
+ error_msg = "오디오가 제공되지 않았습니다."
48
+ logger.error(error_msg)
49
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {error_msg}")
50
+ return None, None, "\n".join(logs)
51
+
52
+ # Progress 업데이트
53
+ progress(0.1, desc="API 클라이언트 초기화 중...")
54
+
55
+ # API 클라이언트 생성
56
+ client = Client(API_URL)
57
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] API 클라이언트 생성 완료")
58
+
59
+ # Progress 업데이트
60
+ progress(0.3, desc="파일 핸들 생성 중...")
61
+
62
+ # 파일 핸들 생성
63
+ image_handle = handle_file(image)
64
+ audio_handle = handle_file(audio)
65
+
66
+ log_msg = f"파일 핸들 생성 완료:\n- Image handle: {type(image_handle)}\n- Audio handle: {type(audio_handle)}"
67
+ logger.debug(log_msg)
68
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] {log_msg}")
69
+
70
+ # Progress 업데이트
71
+ progress(0.5, desc="API 호출 중... (이 과정은 시간이 걸릴 수 있습니다)")
72
+
73
+ # API 호출
74
+ logger.info("API 호출 시작")
75
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] API 호출 시작...")
76
+
77
+ result = client.predict(
78
+ image_path=image_handle,
79
+ audio_path=audio_handle,
80
+ guidance_scale=guidance_scale,
81
+ steps=steps,
82
+ api_name="/generate_animation"
83
+ )
84
+
85
+ # Progress 업데이트
86
+ progress(0.9, desc="결과 처리 중...")
87
+
88
+ # 결과 로깅
89
+ log_msg = f"API 호출 성공!\n결과 타입: {type(result)}\n결과 길이: {len(result) if isinstance(result, (list, tuple)) else 'N/A'}"
90
+ logger.info(log_msg)
91
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] {log_msg}")
92
+
93
+ # 결과 상세 로깅
94
+ if isinstance(result, (list, tuple)) and len(result) >= 2:
95
+ for i, item in enumerate(result):
96
+ log_msg = f"결과[{i}]: {type(item)} - {str(item)[:100]}..."
97
+ logger.debug(log_msg)
98
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] {log_msg}")
99
+
100
+ animation_result = result[0]
101
+ comparison_result = result[1]
102
+
103
+ # 비디오 경로 추출
104
+ animation_video = None
105
+ comparison_video = None
106
+
107
+ if isinstance(animation_result, dict) and 'video' in animation_result:
108
+ animation_video = animation_result['video']
109
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] 애니메이션 비디오 경로: {animation_video}")
110
+
111
+ if isinstance(comparison_result, dict) and 'video' in comparison_result:
112
+ comparison_video = comparison_result['video']
113
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] 비교 비디오 경로: {comparison_video}")
114
+
115
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] === 애니메이션 생성 완료! ===")
116
+
117
+ return animation_video, comparison_video, "\n".join(logs)
118
+ else:
119
+ error_msg = f"예상치 못한 결과 형식: {type(result)}"
120
+ logger.error(error_msg)
121
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {error_msg}")
122
+ return None, None, "\n".join(logs)
123
+
124
+ except Exception as e:
125
+ error_msg = f"오류 발생: {str(e)}"
126
+ logger.error(error_msg)
127
+ logger.error(traceback.format_exc())
128
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {error_msg}")
129
+ logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] 상세 오류:\n{traceback.format_exc()}")
130
+ return None, None, "\n".join(logs)
131
+
132
+ # Gradio 인터페이스 생성
133
+ with gr.Blocks(title="Animation Generator API Test") as demo:
134
+ gr.Markdown("""
135
+ # 🎬 Animation Generator API Test Interface
136
+
137
+ 이 인터페이스는 `http://211.233.58.201:7788/` API를 테스트하기 위한 도구입니다.
138
+
139
+ ## 사용 방법:
140
+ 1. 포트레이트 이미지를 업로드하세요
141
+ 2. 드라이빙 오디오 파일을 업로드하세요
142
+ 3. Guidance Scale과 Inference Steps를 조정하세요
143
+ 4. "Generate Animation" 버튼을 클릭하세요
144
+ """)
145
+
146
+ # API 연결 상태 확인
147
+ with gr.Row():
148
+ with gr.Column():
149
+ connection_status = gr.Textbox(label="API 연결 상태", interactive=False)
150
+ check_connection_btn = gr.Button("API 연결 테스트", variant="secondary")
151
+
152
+ gr.Markdown("---")
153
+
154
+ # 입력 섹션
155
+ with gr.Row():
156
+ with gr.Column():
157
+ image_input = gr.Image(
158
+ label="Portrait Image (any aspect ratio)",
159
+ type="filepath",
160
+ elem_id="image_input"
161
+ )
162
+ audio_input = gr.Audio(
163
+ label="Driving Audio",
164
+ type="filepath",
165
+ elem_id="audio_input"
166
+ )
167
+
168
+ with gr.Column():
169
+ guidance_scale = gr.Slider(
170
+ minimum=1,
171
+ maximum=10,
172
+ value=3,
173
+ step=0.1,
174
+ label="Guidance Scale",
175
+ info="Controls the strength of the guidance"
176
+ )
177
+ steps = gr.Slider(
178
+ minimum=1,
179
+ maximum=50,
180
+ value=10,
181
+ step=1,
182
+ label="Inference Steps",
183
+ info="Number of denoising steps"
184
+ )
185
+
186
+ generate_btn = gr.Button("🚀 Generate Animation", variant="primary", size="lg")
187
+
188
+ # 결과 섹션
189
+ gr.Markdown("## 📽️ Results")
190
+ with gr.Row():
191
+ with gr.Column():
192
+ animation_output = gr.Video(
193
+ label="Animation Result",
194
+ elem_id="animation_output"
195
+ )
196
+ with gr.Column():
197
+ comparison_output = gr.Video(
198
+ label="Side-by-Side Comparison",
199
+ elem_id="comparison_output"
200
+ )
201
+
202
+ # 로그 섹션
203
+ with gr.Accordion("📋 실행 로그", open=True):
204
+ logs_output = gr.Textbox(
205
+ label="Logs",
206
+ lines=10,
207
+ max_lines=20,
208
+ interactive=False,
209
+ elem_id="logs"
210
+ )
211
+
212
+ # 예제 섹션
213
+ gr.Markdown("## 🎯 Examples")
214
+ gr.Examples(
215
+ examples=[
216
+ ["example_portrait.jpg", "example_audio.wav", 3.0, 10],
217
+ ["example_portrait2.jpg", "example_audio2.wav", 2.5, 15],
218
+ ],
219
+ inputs=[image_input, audio_input, guidance_scale, steps],
220
+ outputs=[animation_output, comparison_output, logs_output],
221
+ fn=generate_animation,
222
+ cache_examples=False
223
+ )
224
+
225
+ # 이벤트 핸들러
226
+ check_connection_btn.click(
227
+ fn=test_api_connection,
228
+ outputs=[connection_status, connection_status]
229
+ )
230
+
231
+ generate_btn.click(
232
+ fn=generate_animation,
233
+ inputs=[image_input, audio_input, guidance_scale, steps],
234
+ outputs=[animation_output, comparison_output, logs_output]
235
+ )
236
+
237
+ # 페이지 로드 시 API 연결 테스트
238
+ demo.load(
239
+ fn=test_api_connection,
240
+ outputs=[connection_status, connection_status]
241
+ )
242
+
243
+ # 앱 실행
244
+ if __name__ == "__main__":
245
+ demo.launch(
246
+ server_name="0.0.0.0",
247
+ server_port=7860,
248
+ share=False,
249
+ debug=True
250
+ )