pr0ximaCent commited on
Commit
d9a6cb8
·
verified ·
1 Parent(s): 70d9c39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -50
app.py CHANGED
@@ -7,19 +7,16 @@ import torch.nn as nn
7
  import os
8
  import onnx
9
  import onnxruntime as ort
 
10
 
11
  # === Model Path ===
12
  MODEL_PATH = "bangla_disaster_model.pth"
 
13
 
14
- # Check if model exists
15
- if not os.path.exists(MODEL_PATH):
16
- st.error("❌ Model file not found. Please ensure bangla_disaster_model.pth is uploaded.")
17
- st.stop()
18
-
19
- # Global class list
20
  classes = ['HYD', 'MET', 'FD', 'EQ', 'OTHD']
21
 
22
- # === Model Setup ===
23
  class MultimodalBanglaClassifier(nn.Module):
24
  def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
25
  super(MultimodalBanglaClassifier, self).__init__()
@@ -50,60 +47,51 @@ class MultimodalBanglaClassifier(nn.Module):
50
  fused = self.transformer_fusion(fused).squeeze(1)
51
  return self.classifier(fused)
52
 
53
- # === ONNX Export Helper ===
54
  def export_to_onnx_if_needed(model):
55
- onnx_path = "bangla_disaster_model.onnx"
56
- if os.path.exists(onnx_path):
57
  return
58
  dummy_input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.long)
59
  dummy_attention_mask = torch.ones((1, 128), dtype=torch.long)
60
- dummy_image = torch.randn(1, 3, 224, 224, dtype=torch.float)
61
  torch.onnx.export(
62
  model,
63
  (dummy_input_ids, dummy_attention_mask, dummy_image),
64
- onnx_path,
65
  input_names=["input_ids", "attention_mask", "image"],
66
  output_names=["output"],
67
- dynamic_axes={
68
- "input_ids": {0: "batch"},
69
- "attention_mask": {0: "batch"},
70
- "image": {0: "batch"},
71
- "output": {0: "batch"}
72
- },
73
  opset_version=14,
74
  do_constant_folding=True
75
  )
76
 
 
77
  @st.cache_resource
78
  def load_model_and_tokenizer():
79
  model = MultimodalBanglaClassifier()
80
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
81
  model.eval()
82
  tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
83
  export_to_onnx_if_needed(model)
84
- return model, tokenizer
85
 
 
86
  @st.cache_resource
87
  def load_onnx_session():
88
- return ort.InferenceSession("bangla_disaster_model.onnx")
89
 
 
90
  def predict_with_onnx(session, tokenizer, image, caption):
91
  transform = transforms.Compose([
92
  transforms.Resize((224, 224)),
93
  transforms.ToTensor(),
94
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
95
  ])
96
  image_tensor = transform(image).unsqueeze(0).numpy()
97
-
98
- encoded = tokenizer(
99
- caption,
100
- padding='max_length',
101
- truncation=True,
102
- max_length=128,
103
- return_tensors='pt'
104
- )
105
- input_ids = encoded['input_ids'].numpy()
106
- attention_mask = encoded['attention_mask'].numpy()
107
 
108
  inputs = {
109
  "input_ids": input_ids,
@@ -112,45 +100,46 @@ def predict_with_onnx(session, tokenizer, image, caption):
112
  }
113
  outputs = session.run(None, inputs)
114
  logits = outputs[0]
115
- pred_class = logits.argmax(axis=1).item()
116
- confidence_scores = logits[0].tolist()
117
- return classes[pred_class], confidence_scores
118
 
 
 
 
 
 
 
 
119
  def get_bangla_response(class_name):
120
  responses = {
121
  'HYD': "🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
122
  'MET': "🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
123
  'FD': "🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
124
- 'EQ': "🌍 ভমিকম্প শনাক্ত য়েছে (Earthquake)! নিরাপদ স্থানে যান!",
125
  'OTHD': "😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!"
126
  }
127
- return responses.get(class_name, "🤔 শ্রেণিবিন্যাস করা যায়নি!")
128
 
129
  # === Streamlit UI ===
130
  st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
131
  st.title("🌪️🇧🇩 Bangla Disaster Classifier")
132
  st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")
133
 
134
- model, tokenizer = load_model_and_tokenizer()
135
  onnx_session = load_onnx_session()
136
 
137
  uploaded_file = st.file_uploader(
138
  "🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
139
  type=['jpg', 'png', 'jpeg'],
140
- accept_multiple_files=False,
141
  key="disaster_image_uploader",
142
  help="ছবি আপলোড করতে এখানে ক্লিক করুন অথবা drag & drop করুন"
143
  )
144
 
145
- if uploaded_file is not None:
146
  st.success(f"✅ ছবি আপলোড সফল: {uploaded_file.name}")
147
  else:
148
  st.info("📁 অনুগ্রহ করে একটি ছবি আপলোড করুন")
149
 
150
  caption = st.text_area("✍️ বাংলায় একটি ক্যাপশন লিখুন", "")
151
 
152
- st.caption("🎯 পূর্বাভাস মোড: উচ্চ নির্ভুলতা (High Accuracy)")
153
-
154
  col1, col2 = st.columns([1, 1])
155
  submit = col1.button("🔍 পূর্বাভাস দিন")
156
  clear = col2.button("🧹 রিসেট করুন")
@@ -163,16 +152,15 @@ if submit and uploaded_file and caption:
163
  st.image(img, caption="আপলোড করা ছবি", use_container_width=True)
164
 
165
  with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
166
- progress_bar = st.progress(0, text="ছবি প্রক্রিয়াকরণ... (Processing image...)")
167
- progress_bar.progress(50, text="বিশ্লেষণ চলছে... (Running inference...)")
168
-
169
  prediction, probs = predict_with_onnx(onnx_session, tokenizer, img, caption)
 
170
 
171
- progress_bar.progress(100, text="✅ সম্পূর্ণ! (Complete!)")
172
  progress_bar.empty()
173
 
174
  st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")
175
-
176
  col1, col2 = st.columns([2, 1])
177
  with col1:
178
  st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
@@ -188,6 +176,6 @@ if submit and uploaded_file and caption:
188
  'OTHD': 'কোনো দুর্যোগ নয়'
189
  }
190
  for i, class_code in enumerate(classes):
191
- percentage = probs[i] * 100
192
- st.write(f"**{class_names[class_code]}**: {percentage:.1f}%")
193
- st.progress(probs[i])
 
7
  import os
8
  import onnx
9
  import onnxruntime as ort
10
+ import numpy as np
11
 
12
  # === Model Path ===
13
  MODEL_PATH = "bangla_disaster_model.pth"
14
+ ONNX_PATH = "bangla_disaster_model.onnx"
15
 
16
+ # === Class Labels ===
 
 
 
 
 
17
  classes = ['HYD', 'MET', 'FD', 'EQ', 'OTHD']
18
 
19
+ # === Model Architecture (used only for export) ===
20
  class MultimodalBanglaClassifier(nn.Module):
21
  def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
22
  super(MultimodalBanglaClassifier, self).__init__()
 
47
  fused = self.transformer_fusion(fused).squeeze(1)
48
  return self.classifier(fused)
49
 
50
+ # === ONNX Export ===
51
  def export_to_onnx_if_needed(model):
52
+ if os.path.exists(ONNX_PATH):
 
53
  return
54
  dummy_input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.long)
55
  dummy_attention_mask = torch.ones((1, 128), dtype=torch.long)
56
+ dummy_image = torch.randn(1, 3, 224, 224)
57
  torch.onnx.export(
58
  model,
59
  (dummy_input_ids, dummy_attention_mask, dummy_image),
60
+ ONNX_PATH,
61
  input_names=["input_ids", "attention_mask", "image"],
62
  output_names=["output"],
63
+ dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}, "image": {0: "batch"}, "output": {0: "batch"}},
 
 
 
 
 
64
  opset_version=14,
65
  do_constant_folding=True
66
  )
67
 
68
+ # === Load Model and Tokenizer for Exporting Only ===
69
  @st.cache_resource
70
  def load_model_and_tokenizer():
71
  model = MultimodalBanglaClassifier()
72
+ model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
73
  model.eval()
74
  tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
75
  export_to_onnx_if_needed(model)
76
+ return tokenizer
77
 
78
+ # === Load ONNX Session ===
79
  @st.cache_resource
80
  def load_onnx_session():
81
+ return ort.InferenceSession(ONNX_PATH)
82
 
83
+ # === ONNX Prediction ===
84
  def predict_with_onnx(session, tokenizer, image, caption):
85
  transform = transforms.Compose([
86
  transforms.Resize((224, 224)),
87
  transforms.ToTensor(),
88
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
89
+ std=[0.229, 0.224, 0.225])
90
  ])
91
  image_tensor = transform(image).unsqueeze(0).numpy()
92
+ encoded = tokenizer(caption, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
93
+ input_ids = encoded["input_ids"].numpy()
94
+ attention_mask = encoded["attention_mask"].numpy()
 
 
 
 
 
 
 
95
 
96
  inputs = {
97
  "input_ids": input_ids,
 
100
  }
101
  outputs = session.run(None, inputs)
102
  logits = outputs[0]
 
 
 
103
 
104
+ # Softmax to get probabilities
105
+ exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
106
+ probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
107
+ pred_class = np.argmax(probs, axis=1)[0]
108
+ return classes[pred_class], probs[0].tolist()
109
+
110
+ # === Bangla Labels ===
111
  def get_bangla_response(class_name):
112
  responses = {
113
  'HYD': "🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
114
  'MET': "🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
115
  'FD': "🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
116
+ 'EQ': "🌍 ভমিকম্প শনাক্ত ���য়েছে (Earthquake)! নিরাপদ স্থানে যান!",
117
  'OTHD': "😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!"
118
  }
119
+ return responses.get(class_name, "🤔 শ্রেণিবিন্যাস করা যায়নি")
120
 
121
  # === Streamlit UI ===
122
  st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
123
  st.title("🌪️🇧🇩 Bangla Disaster Classifier")
124
  st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")
125
 
126
+ tokenizer = load_model_and_tokenizer()
127
  onnx_session = load_onnx_session()
128
 
129
  uploaded_file = st.file_uploader(
130
  "🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
131
  type=['jpg', 'png', 'jpeg'],
 
132
  key="disaster_image_uploader",
133
  help="ছবি আপলোড করতে এখানে ক্লিক করুন অথবা drag & drop করুন"
134
  )
135
 
136
+ if uploaded_file:
137
  st.success(f"✅ ছবি আপলোড সফল: {uploaded_file.name}")
138
  else:
139
  st.info("📁 অনুগ্রহ করে একটি ছবি আপলোড করুন")
140
 
141
  caption = st.text_area("✍️ বাংলায় একটি ক্যাপশন লিখুন", "")
142
 
 
 
143
  col1, col2 = st.columns([1, 1])
144
  submit = col1.button("🔍 পূর্বাভাস দিন")
145
  clear = col2.button("🧹 রিসেট করুন")
 
152
  st.image(img, caption="আপলোড করা ছবি", use_container_width=True)
153
 
154
  with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
155
+ progress_bar = st.progress(0, text="প্রক্রিয়াকরণ শুরু হচ্ছে...")
156
+
157
+ progress_bar.progress(50, text="বিশ্লেষণ চলছে...")
158
  prediction, probs = predict_with_onnx(onnx_session, tokenizer, img, caption)
159
+ progress_bar.progress(100, text="✅ বিশ্লেষণ সম্পন্ন!")
160
 
 
161
  progress_bar.empty()
162
 
163
  st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")
 
164
  col1, col2 = st.columns([2, 1])
165
  with col1:
166
  st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
 
176
  'OTHD': 'কোনো দুর্যোগ নয়'
177
  }
178
  for i, class_code in enumerate(classes):
179
+ percentage = probs[i]
180
+ st.write(f"**{class_names[class_code]}**: {percentage:.1%}")
181
+ st.progress(min(max(percentage, 0.0), 1.0)) # ensure range [0.0, 1.0]