Update app.py
Browse files
app.py
CHANGED
|
@@ -343,6 +343,24 @@ def load_examples_from_directory(base_dir):
|
|
| 343 |
|
| 344 |
def create_gradio_interface():
|
| 345 |
predictor = ClimatePredictor('best_model.pth')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
with gr.Blocks() as interface:
|
| 348 |
gr.Markdown("# Renewable Energy Potential Predictor")
|
|
@@ -351,10 +369,10 @@ def create_gradio_interface():
|
|
| 351 |
with gr.Row():
|
| 352 |
# 입력 섹션 (1/3 크기)
|
| 353 |
with gr.Column(scale=1):
|
| 354 |
-
rgb_input = gr.Image(label="RGB Satellite Image", type="numpy")
|
| 355 |
-
ndvi_input = gr.Image(label="NDVI Image", type="numpy")
|
| 356 |
-
terrain_input = gr.Image(label="Terrain Map", type="numpy")
|
| 357 |
-
elevation_input = gr.File(label="Elevation Data (NPY file)")
|
| 358 |
|
| 359 |
with gr.Row():
|
| 360 |
wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
|
|
@@ -369,29 +387,36 @@ def create_gradio_interface():
|
|
| 369 |
# 출력 섹션 (2/3 크기)
|
| 370 |
with gr.Column(scale=2):
|
| 371 |
output_plot = gr.Plot(label="Prediction Results")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
# 예측 버튼 클릭 이벤트 연결
|
| 374 |
predict_btn.click(
|
| 375 |
-
fn=
|
| 376 |
inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
|
| 377 |
wind_speed, wind_direction, temperature, humidity],
|
| 378 |
outputs=output_plot
|
| 379 |
)
|
| 380 |
-
|
| 381 |
-
# 예제 섹션
|
| 382 |
-
examples = load_examples_from_directory("filtered_climate_data")
|
| 383 |
-
gr.Examples(
|
| 384 |
-
examples=examples,
|
| 385 |
-
inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
|
| 386 |
-
wind_speed, wind_direction, temperature, humidity],
|
| 387 |
-
outputs=output_plot,
|
| 388 |
-
fn=predictor.predict_from_inputs, # 예제 실행에 사용할 함수 지정
|
| 389 |
-
cache_examples=True,
|
| 390 |
-
api_name=False # API 엔드포인트 생성 비활성화
|
| 391 |
-
)
|
| 392 |
|
| 393 |
return interface
|
| 394 |
|
| 395 |
if __name__ == "__main__":
|
| 396 |
interface = create_gradio_interface()
|
| 397 |
-
interface.launch(
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
|
| 344 |
def create_gradio_interface():
|
| 345 |
predictor = ClimatePredictor('best_model.pth')
|
| 346 |
+
|
| 347 |
+
def process_elevation_file(file_obj):
|
| 348 |
+
if isinstance(file_obj, str): # 파일 경로인 경우
|
| 349 |
+
return np.load(file_obj)
|
| 350 |
+
else: # UploadedFile 객체인 경우
|
| 351 |
+
return np.load(file_obj.name)
|
| 352 |
+
|
| 353 |
+
def predict_with_processing(*args):
|
| 354 |
+
rgb_image, ndvi_image, terrain_image, elevation_file = args[:4]
|
| 355 |
+
weather_params = args[4:]
|
| 356 |
+
|
| 357 |
+
# elevation 파일 처리
|
| 358 |
+
elevation_data = process_elevation_file(elevation_file)
|
| 359 |
+
|
| 360 |
+
return predictor.predict_from_inputs(
|
| 361 |
+
rgb_image, ndvi_image, terrain_image, elevation_data,
|
| 362 |
+
*weather_params
|
| 363 |
+
)
|
| 364 |
|
| 365 |
with gr.Blocks() as interface:
|
| 366 |
gr.Markdown("# Renewable Energy Potential Predictor")
|
|
|
|
| 369 |
with gr.Row():
|
| 370 |
# 입력 섹션 (1/3 크기)
|
| 371 |
with gr.Column(scale=1):
|
| 372 |
+
rgb_input = gr.Image(label="RGB Satellite Image", type="numpy", height=200)
|
| 373 |
+
ndvi_input = gr.Image(label="NDVI Image", type="numpy", height=200)
|
| 374 |
+
terrain_input = gr.Image(label="Terrain Map", type="numpy", height=200)
|
| 375 |
+
elevation_input = gr.File(label="Elevation Data (NPY file)", height=50)
|
| 376 |
|
| 377 |
with gr.Row():
|
| 378 |
wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
|
|
|
|
| 387 |
# 출력 섹션 (2/3 크기)
|
| 388 |
with gr.Column(scale=2):
|
| 389 |
output_plot = gr.Plot(label="Prediction Results")
|
| 390 |
+
|
| 391 |
+
# 예제 데이터 로드
|
| 392 |
+
examples = load_examples_from_directory("filtered_climate_data")
|
| 393 |
+
|
| 394 |
+
# 예제 갤러리 생성
|
| 395 |
+
with gr.Row():
|
| 396 |
+
gr.Examples(
|
| 397 |
+
examples=examples,
|
| 398 |
+
inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
|
| 399 |
+
wind_speed, wind_direction, temperature, humidity],
|
| 400 |
+
outputs=output_plot,
|
| 401 |
+
fn=predict_with_processing,
|
| 402 |
+
cache_examples=True,
|
| 403 |
+
label="Click any example to run",
|
| 404 |
+
examples_per_page=5
|
| 405 |
+
)
|
| 406 |
|
| 407 |
# 예측 버튼 클릭 이벤트 연결
|
| 408 |
predict_btn.click(
|
| 409 |
+
fn=predict_with_processing,
|
| 410 |
inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
|
| 411 |
wind_speed, wind_direction, temperature, humidity],
|
| 412 |
outputs=output_plot
|
| 413 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
|
| 415 |
return interface
|
| 416 |
|
| 417 |
if __name__ == "__main__":
|
| 418 |
interface = create_gradio_interface()
|
| 419 |
+
interface.launch(
|
| 420 |
+
share=True,
|
| 421 |
+
enable_queue=True
|
| 422 |
+
)
|