HoeioUser commited on
Commit
7c44bb7
·
verified ·
1 Parent(s): 061e569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -18
app.py CHANGED
@@ -343,6 +343,24 @@ def load_examples_from_directory(base_dir):
343
 
344
  def create_gradio_interface():
345
  predictor = ClimatePredictor('best_model.pth')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
347
  with gr.Blocks() as interface:
348
  gr.Markdown("# Renewable Energy Potential Predictor")
@@ -351,10 +369,10 @@ def create_gradio_interface():
351
  with gr.Row():
352
  # 입력 섹션 (1/3 크기)
353
  with gr.Column(scale=1):
354
- rgb_input = gr.Image(label="RGB Satellite Image", type="numpy")
355
- ndvi_input = gr.Image(label="NDVI Image", type="numpy")
356
- terrain_input = gr.Image(label="Terrain Map", type="numpy")
357
- elevation_input = gr.File(label="Elevation Data (NPY file)")
358
 
359
  with gr.Row():
360
  wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
@@ -369,29 +387,36 @@ def create_gradio_interface():
369
  # 출력 섹션 (2/3 크기)
370
  with gr.Column(scale=2):
371
  output_plot = gr.Plot(label="Prediction Results")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
  # 예측 버튼 클릭 이벤트 연결
374
  predict_btn.click(
375
- fn=predictor.predict_from_inputs,
376
  inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
377
  wind_speed, wind_direction, temperature, humidity],
378
  outputs=output_plot
379
  )
380
-
381
- # 예제 섹션
382
- examples = load_examples_from_directory("filtered_climate_data")
383
- gr.Examples(
384
- examples=examples,
385
- inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
386
- wind_speed, wind_direction, temperature, humidity],
387
- outputs=output_plot,
388
- fn=predictor.predict_from_inputs, # 예제 실행에 사용할 함수 지정
389
- cache_examples=True,
390
- api_name=False # API 엔드포인트 생성 비활성화
391
- )
392
 
393
  return interface
394
 
395
  if __name__ == "__main__":
396
  interface = create_gradio_interface()
397
- interface.launch()
 
 
 
 
343
 
344
  def create_gradio_interface():
345
  predictor = ClimatePredictor('best_model.pth')
346
+
347
+ def process_elevation_file(file_obj):
348
+ if isinstance(file_obj, str): # 파일 경로인 경우
349
+ return np.load(file_obj)
350
+ else: # UploadedFile 객체인 경우
351
+ return np.load(file_obj.name)
352
+
353
+ def predict_with_processing(*args):
354
+ rgb_image, ndvi_image, terrain_image, elevation_file = args[:4]
355
+ weather_params = args[4:]
356
+
357
+ # elevation 파일 처리
358
+ elevation_data = process_elevation_file(elevation_file)
359
+
360
+ return predictor.predict_from_inputs(
361
+ rgb_image, ndvi_image, terrain_image, elevation_data,
362
+ *weather_params
363
+ )
364
 
365
  with gr.Blocks() as interface:
366
  gr.Markdown("# Renewable Energy Potential Predictor")
 
369
  with gr.Row():
370
  # 입력 섹션 (1/3 크기)
371
  with gr.Column(scale=1):
372
+ rgb_input = gr.Image(label="RGB Satellite Image", type="numpy", height=200)
373
+ ndvi_input = gr.Image(label="NDVI Image", type="numpy", height=200)
374
+ terrain_input = gr.Image(label="Terrain Map", type="numpy", height=200)
375
+ elevation_input = gr.File(label="Elevation Data (NPY file)", height=50)
376
 
377
  with gr.Row():
378
  wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
 
387
  # 출력 섹션 (2/3 크기)
388
  with gr.Column(scale=2):
389
  output_plot = gr.Plot(label="Prediction Results")
390
+
391
+ # 예제 데이터 로드
392
+ examples = load_examples_from_directory("filtered_climate_data")
393
+
394
+ # 예제 갤러리 생성
395
+ with gr.Row():
396
+ gr.Examples(
397
+ examples=examples,
398
+ inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
399
+ wind_speed, wind_direction, temperature, humidity],
400
+ outputs=output_plot,
401
+ fn=predict_with_processing,
402
+ cache_examples=True,
403
+ label="Click any example to run",
404
+ examples_per_page=5
405
+ )
406
 
407
  # 예측 버튼 클릭 이벤트 연결
408
  predict_btn.click(
409
+ fn=predict_with_processing,
410
  inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
411
  wind_speed, wind_direction, temperature, humidity],
412
  outputs=output_plot
413
  )
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  return interface
416
 
417
  if __name__ == "__main__":
418
  interface = create_gradio_interface()
419
+ interface.launch(
420
+ share=True,
421
+ enable_queue=True
422
+ )