Pingsz commited on
Commit
6394b10
·
verified ·
1 Parent(s): 4a439c2

Delete t.sh

Browse files
Files changed (1) hide show
  1. t.sh +0 -180
t.sh DELETED
@@ -1,180 +0,0 @@
1
- #!/bin/bash
2
- set -e
3
-
4
- # === Create folder structure ===
5
- mkdir -p geo-risk-space/src geo-risk-space/model geo-risk-space/keys
6
-
7
- # === model placeholder ===
8
- cat > geo-risk-space/model/README.txt <<'EOF'
9
- Place your trained model weights here as:
10
- geo-risk-space/model/geo_model.pth
11
- EOF
12
-
13
- # === src/model.py ===
14
- cat > geo-risk-space/src/model.py <<'EOF'
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- import torchvision.models as models
19
-
20
- class CompactGeoEmbed(nn.Module):
21
- def __init__(self, embed_c=32, proj_dim=96, pretrained=False):
22
- super().__init__()
23
- backbone = models.mobilenet_v2(
24
- weights=None if not pretrained else models.MobileNet_V2_Weights.IMAGENET1K_V1
25
- ).features
26
- self.backbone = backbone
27
- self.reduce = nn.Conv2d(1280, embed_c, 1)
28
- self.elev_conv = nn.Conv2d(1, embed_c, 3, padding=1)
29
- self.conv_head = nn.Sequential(
30
- nn.Conv2d(embed_c * 2, embed_c, 3, padding=1),
31
- nn.ReLU(True),
32
- nn.Conv2d(embed_c, embed_c, 3, padding=1),
33
- nn.ReLU(True),
34
- )
35
- self.proj = nn.Sequential(
36
- nn.AdaptiveAvgPool2d(1),
37
- nn.Flatten(),
38
- nn.Linear(embed_c, proj_dim),
39
- nn.ReLU(),
40
- nn.Linear(proj_dim, proj_dim),
41
- )
42
- self.risk_head = nn.Sequential(
43
- nn.Linear(proj_dim, proj_dim // 2),
44
- nn.ReLU(),
45
- nn.Linear(proj_dim // 2, 1),
46
- nn.Sigmoid(),
47
- )
48
-
49
- def forward(self, img, elev=None):
50
- x = self.backbone(img)
51
- x = self.reduce(x)
52
- if elev is None:
53
- elev = torch.zeros(x.size(0), 1, img.size(2), img.size(3), device=x.device)
54
- elev = F.interpolate(elev, size=x.shape[2:], mode="bilinear", align_corners=False)
55
- e = self.elev_conv(elev)
56
- x = self.conv_head(torch.cat([x, e], 1))
57
- p = self.proj(x)
58
- p = F.normalize(p, dim=1)
59
- risk = self.risk_head(p).squeeze(-1)
60
- return x, p, risk
61
- EOF
62
-
63
- # === app.py ===
64
- cat > geo-risk-space/app.py <<'EOF'
65
- import os, json, io, torch, requests, numpy as np, gradio as gr, ee
66
- from PIL import Image
67
- import torchvision.transforms as T
68
- from huggingface_hub import hf_hub_download
69
- from src.model import CompactGeoEmbed
70
-
71
- MODEL_LOCAL = "/app/model/geo_model.pth"
72
- HF_REPO_ID = "USERNAME/geo-risk-model"
73
- HF_FILENAME = "geo_model.pth"
74
- GEE_KEY_PATH = "/app/keys/gee_service_account.json"
75
-
76
- device = torch.device("cpu")
77
-
78
- def init_gee():
79
- if os.path.exists(GEE_KEY_PATH):
80
- with open(GEE_KEY_PATH) as f:
81
- svc = json.load(f)
82
- ee.Initialize(ee.ServiceAccountCredentials(svc["client_email"], GEE_KEY_PATH))
83
- print("✅ GEE initialized")
84
- else:
85
- print("⚠️ Missing GEE key, skipping Earth Engine init")
86
-
87
- def load_model():
88
- model = CompactGeoEmbed(32, 96)
89
- state = None
90
- if os.path.exists(MODEL_LOCAL):
91
- state = torch.load(MODEL_LOCAL, map_location="cpu")
92
- else:
93
- try:
94
- p = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
95
- state = torch.load(p, map_location="cpu")
96
- except Exception as e:
97
- print("⚠️ No weights found:", e)
98
- if state:
99
- model.load_state_dict(state)
100
- model.to(device).eval()
101
- return model
102
-
103
- init_gee()
104
- MODEL = load_model()
105
- tf = T.Compose([T.Resize((128, 128)), T.ToTensor()])
106
-
107
- def fetch_satellite(lat, lon):
108
- p = ee.Geometry.Point([lon, lat])
109
- srtm = ee.Image("USGS/SRTMGL1_003")
110
- ndvi = ee.ImageCollection("MODIS/061/MOD13A2").select("NDVI").mean()
111
- rgb = (
112
- ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
113
- .filterBounds(p)
114
- .filterDate("2023-01-01", "2024-01-01")
115
- .sort("CLOUDY_PIXEL_PERCENTAGE")
116
- .first()
117
- )
118
- region = p.buffer(1000).bounds()
119
- elev_url = srtm.visualize(min=0, max=3000).getThumbURL(
120
- {"region": region, "dimensions": "128x128", "format": "png"}
121
- )
122
- ndvi_url = ndvi.visualize(min=-2000, max=10000, palette=["white", "green"]).getThumbURL(
123
- {"region": region, "dimensions": "128x128", "format": "png"}
124
- )
125
- elev = Image.open(io.BytesIO(requests.get(elev_url).content)).convert("L")
126
- ndvi_img = Image.open(io.BytesIO(requests.get(ndvi_url).content)).convert("RGB")
127
- return elev, ndvi_img
128
-
129
- def preprocess(img):
130
- if img is None:
131
- img = Image.new("RGB", (128, 128), (127, 127, 127))
132
- return tf(img.convert("RGB")).unsqueeze(0)
133
-
134
- def predict(lat, lon, img):
135
- lat, lon = float(lat), float(lon)
136
- try:
137
- elev, sat = fetch_satellite(lat, lon)
138
- except Exception as e:
139
- elev = Image.new("L", (128, 128), 127)
140
- sat = img
141
- x = preprocess(sat).to(device)
142
- e = torch.tensor(np.array(elev) / 255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
143
- with torch.no_grad():
144
- _, _, r = MODEL(x, e)
145
- return f"Predicted Risk Score: {float(r.item()):.4f}"
146
-
147
- with gr.Blocks() as demo:
148
- gr.Markdown("## 🌍 Geo-Risk Prediction (GEE-powered)")
149
- lat = gr.Number(value=51.5072, label="Latitude")
150
- lon = gr.Number(value=-0.1276, label="Longitude")
151
- img = gr.Image(type="pil", label="Optional RGB (Sentinel-2 fallback)")
152
- out = gr.Textbox(label="Prediction")
153
- gr.Button("Run").click(fn=predict, inputs=[lat, lon, img], outputs=out)
154
-
155
- if __name__ == "__main__":
156
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
157
- EOF
158
-
159
- # === requirements.txt ===
160
- cat > geo-risk-space/requirements.txt <<'EOF'
161
- torch
162
- torchvision
163
- gradio
164
- pillow
165
- requests
166
- huggingface_hub
167
- earthengine-api
168
- EOF
169
-
170
- # === README.md ===
171
- cat > geo-risk-space/README.md <<'EOF'
172
- # 🌍 Geo-Risk Prediction (GEE + Gradio)
173
-
174
- ### Setup
175
-
176
- ```bash
177
- bash setup_space.sh
178
- cd geo-risk-space
179
- pip install -r requirements.txt
180
- python app.py