Update app.py
Browse files
app.py
CHANGED
|
@@ -212,57 +212,78 @@ class ClimatePredictor:
|
|
| 212 |
transforms.Normalize(mean=[0.5], std=[0.5])
|
| 213 |
])
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
|
| 216 |
elevation_data, wind_speed, wind_direction,
|
| 217 |
temperature, humidity):
|
| 218 |
"""Gradio ์ธํฐํ์ด์ค์ฉ ์์ธก ํจ์"""
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
# Gradio ์ธํฐํ์ด์ค ์์ฑ
|
| 266 |
def create_gradio_interface():
|
| 267 |
predictor = ClimatePredictor('best_model.pth')
|
| 268 |
|
|
@@ -282,8 +303,8 @@ def create_gradio_interface():
|
|
| 282 |
fn=predict_and_visualize,
|
| 283 |
inputs=[
|
| 284 |
gr.Image(label="RGB Satellite Image", type="numpy"),
|
| 285 |
-
gr.Image(label="NDVI Image", type="numpy"),
|
| 286 |
-
gr.Image(label="Terrain Map", type="numpy"),
|
| 287 |
gr.File(label="Elevation Data (NPY file)"),
|
| 288 |
gr.Number(label="Wind Speed (m/s)", value=5.0),
|
| 289 |
gr.Number(label="Wind Direction (degrees)", value=180.0),
|
|
@@ -292,7 +313,8 @@ def create_gradio_interface():
|
|
| 292 |
],
|
| 293 |
outputs=gr.Plot(label="Prediction Results"),
|
| 294 |
title="Renewable Energy Potential Predictor",
|
| 295 |
-
description="Upload satellite imagery and environmental data to predict wind and solar power potential.
|
|
|
|
| 296 |
examples=[
|
| 297 |
[
|
| 298 |
"examples/rgb_example.png",
|
|
@@ -305,6 +327,7 @@ def create_gradio_interface():
|
|
| 305 |
)
|
| 306 |
return interface
|
| 307 |
|
|
|
|
| 308 |
if __name__ == "__main__":
|
| 309 |
interface = create_gradio_interface()
|
| 310 |
-
interface.launch()
|
|
|
|
| 212 |
transforms.Normalize(mean=[0.5], std=[0.5])
|
| 213 |
])
|
| 214 |
|
| 215 |
+
def convert_to_single_channel(self, image_array):
|
| 216 |
+
"""RGB ์ด๋ฏธ์ง๋ฅผ ๋จ์ผ ์ฑ๋๋ก ๋ณํ"""
|
| 217 |
+
if len(image_array.shape) == 3:
|
| 218 |
+
# RGB to grayscale conversion
|
| 219 |
+
return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
|
| 220 |
+
return image_array
|
| 221 |
+
|
| 222 |
def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
|
| 223 |
elevation_data, wind_speed, wind_direction,
|
| 224 |
temperature, humidity):
|
| 225 |
"""Gradio ์ธํฐํ์ด์ค์ฉ ์์ธก ํจ์"""
|
| 226 |
+
try:
|
| 227 |
+
# RGB ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ
|
| 228 |
+
rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
|
| 229 |
+
|
| 230 |
+
# NDVI ์ด๋ฏธ์ง๋ฅผ ๋จ์ผ ์ฑ๋๋ก ๋ณํ ํ ์ ์ฒ๋ฆฌ
|
| 231 |
+
ndvi_gray = self.convert_to_single_channel(ndvi_image)
|
| 232 |
+
ndvi_tensor = self.single_channel_transform(Image.fromarray(ndvi_gray.astype(np.uint8))).unsqueeze(0)
|
| 233 |
+
|
| 234 |
+
# Terrain ์ด๋ฏธ์ง๋ฅผ ๋จ์ผ ์ฑ๋๋ก ๋ณํ ํ ์ ์ฒ๋ฆฌ
|
| 235 |
+
terrain_gray = self.convert_to_single_channel(terrain_image)
|
| 236 |
+
terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_gray.astype(np.uint8))).unsqueeze(0)
|
| 237 |
+
|
| 238 |
+
# Print shapes for debugging
|
| 239 |
+
print(f"RGB tensor shape: {rgb_tensor.shape}")
|
| 240 |
+
print(f"NDVI tensor shape: {ndvi_tensor.shape}")
|
| 241 |
+
print(f"Terrain tensor shape: {terrain_tensor.shape}")
|
| 242 |
+
|
| 243 |
+
# ๊ณ ๋ ๋ฐ์ดํฐ ์ฒ๋ฆฌ
|
| 244 |
+
elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
|
| 245 |
+
elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
|
| 246 |
+
|
| 247 |
+
# ๊ธฐ์ ๋ฐ์ดํฐ ์ฒ๋ฆฌ
|
| 248 |
+
weather_features = np.array([wind_speed, wind_direction, temperature, humidity])
|
| 249 |
+
weather_features = (weather_features - weather_features.min()) / (weather_features.max() - weather_features.min())
|
| 250 |
+
weather_features = torch.tensor(weather_features, dtype=torch.float32).unsqueeze(0)
|
| 251 |
+
|
| 252 |
+
# ๋๋ฐ์ด์ค๋ก ์ด๋
|
| 253 |
+
sample = {
|
| 254 |
+
'rgb': rgb_tensor.to(self.device),
|
| 255 |
+
'ndvi': ndvi_tensor.to(self.device),
|
| 256 |
+
'terrain': terrain_tensor.to(self.device),
|
| 257 |
+
'elevation': elevation_tensor.to(self.device),
|
| 258 |
+
'weather_features': weather_features.to(self.device)
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
# ์์ธก
|
| 262 |
+
with torch.no_grad():
|
| 263 |
+
wind_pred, solar_pred = self.model(sample)
|
| 264 |
+
|
| 265 |
+
# ๊ฒฐ๊ณผ๋ฅผ numpy ๋ฐฐ์ด๋ก ๋ณํ
|
| 266 |
+
wind_map = wind_pred.cpu().numpy()[0, 0]
|
| 267 |
+
solar_map = solar_pred.cpu().numpy()[0, 0]
|
| 268 |
+
|
| 269 |
+
# ๊ฒฐ๊ณผ ์๊ฐํ
|
| 270 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
| 271 |
+
|
| 272 |
+
# ํ๋ ฅ ๋ฐ์ ์ ์ฌ๋ ์๊ฐํ
|
| 273 |
+
sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
|
| 274 |
+
ax1.set_title('Wind Power Potential Map')
|
| 275 |
+
|
| 276 |
+
# ํ์๊ด ๋ฐ์ ์ ์ฌ๋ ์๊ฐํ
|
| 277 |
+
sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
|
| 278 |
+
ax2.set_title('Solar Power Potential Map')
|
| 279 |
+
|
| 280 |
+
plt.tight_layout()
|
| 281 |
+
|
| 282 |
+
return fig
|
| 283 |
+
except Exception as e:
|
| 284 |
+
print(f"Error in prediction: {str(e)}")
|
| 285 |
+
raise e
|
| 286 |
|
|
|
|
| 287 |
def create_gradio_interface():
|
| 288 |
predictor = ClimatePredictor('best_model.pth')
|
| 289 |
|
|
|
|
| 303 |
fn=predict_and_visualize,
|
| 304 |
inputs=[
|
| 305 |
gr.Image(label="RGB Satellite Image", type="numpy"),
|
| 306 |
+
gr.Image(label="NDVI Image (will be converted to grayscale)", type="numpy"),
|
| 307 |
+
gr.Image(label="Terrain Map (will be converted to grayscale)", type="numpy"),
|
| 308 |
gr.File(label="Elevation Data (NPY file)"),
|
| 309 |
gr.Number(label="Wind Speed (m/s)", value=5.0),
|
| 310 |
gr.Number(label="Wind Direction (degrees)", value=180.0),
|
|
|
|
| 313 |
],
|
| 314 |
outputs=gr.Plot(label="Prediction Results"),
|
| 315 |
title="Renewable Energy Potential Predictor",
|
| 316 |
+
description="""Upload satellite imagery and environmental data to predict wind and solar power potential.
|
| 317 |
+
Note: NDVI and Terrain images will be automatically converted to grayscale.""",
|
| 318 |
examples=[
|
| 319 |
[
|
| 320 |
"examples/rgb_example.png",
|
|
|
|
| 327 |
)
|
| 328 |
return interface
|
| 329 |
|
| 330 |
+
# Hugging Face Spaces์์ ์ฑ ์คํ
|
| 331 |
if __name__ == "__main__":
|
| 332 |
interface = create_gradio_interface()
|
| 333 |
+
interface.launch()
|