pr0ximaCent committed on
Commit
1350413
·
verified ·
1 Parent(s): 0c6477f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import torch
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import torch.nn as nn
8
+
9
+ # === Model Setup ===
10
class MultimodalBanglaClassifier(nn.Module):
    """Classify a (caption, image) pair by fusing text and image features.

    The text branch is a Hugging Face ``AutoModel`` and the image branch is a
    torchvision EfficientNet-B3 whose classification head is replaced by an
    identity. Both feature vectors are concatenated, projected to 512
    dimensions, passed through a two-layer transformer encoder as a
    length-one sequence, and finally mapped to ``num_classes`` logits.

    Args:
        text_model_name: Identifier handed to ``AutoModel.from_pretrained``.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
        super().__init__()
        from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights

        # Text branch: the first six encoder layers are frozen.
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.text_model.encoder.layer[:6].requires_grad_(False)

        # Image branch: ImageNet-initialised backbone with its head removed,
        # so calling it yields features rather than ImageNet logits.
        backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        backbone.classifier = nn.Identity()
        self.image_model = backbone

        # Input width is 768 + 1536 — presumably the text hidden size plus the
        # EfficientNet-B3 feature size; confirm before swapping either branch.
        self.proj = nn.Linear(768 + 1536, 512)

        fusion_layer = nn.TransformerEncoderLayer(d_model=512, nhead=4, batch_first=True)
        self.transformer_fusion = nn.TransformerEncoder(fusion_layer, num_layers=2)

        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_ids, attention_mask, image):
        """Return class logits for a batch of tokenised captions and images.

        Args:
            input_ids: Token ids forwarded to the text model.
            attention_mask: Attention mask forwarded to the text model.
            image: Image batch forwarded to the image model.

        Returns:
            The output of the classification head for each batch element.
        """
        encoder_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Only the hidden state at sequence position 0 is used as the text feature.
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        visual_state = self.image_model(image)

        joint = torch.cat((first_token_state, visual_state), dim=1)
        # The fusion encoder runs with batch_first=True, so add a sequence
        # axis of length one before it and drop it again afterwards.
        fusion_input = self.proj(joint).unsqueeze(1)
        fusion_output = self.transformer_fusion(fusion_input).squeeze(1)
        return self.classifier(fusion_output)
39
+
40
@st.cache_resource
def load_model():
    """Build the classifier, load its trained weights, and return it in eval mode.

    The function is wrapped in ``st.cache_resource`` so the returned model
    object is reused instead of being rebuilt on every call.

    Returns:
        A ``MultimodalBanglaClassifier`` with the weights from
        ``bangla_disaster_model.pth`` loaded onto the CPU and ``eval()`` applied.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint does not match the model's parameters.
    """
    # Read the checkpoint first so a missing or unreadable file fails before
    # the pretrained backbones are constructed.
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state dict needs and avoids executing arbitrary pickled code.
    state_dict = torch.load(
        "bangla_disaster_model.pth",
        map_location=torch.device('cpu'),
        weights_only=True,
    )

    model = MultimodalBanglaClassifier()
    model.load_state_dict(state_dict)
    model.eval()
    return model
46
+
47
def predict(model, tokenizer, image, caption):
    """Return the predicted class label for one image and its caption.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``image`` keyword arguments and returning per-class scores.
        tokenizer: Callable that tokenises ``caption`` and returns a mapping
            with ``'input_ids'`` and ``'attention_mask'`` entries.
        image: Image accepted by the torchvision transforms below
            (presumably a PIL image — confirm against the caller).
        caption: Caption text to tokenise.

    Returns:
        One of ``'HYD'``, ``'MET'``, ``'FD'``, ``'EQ'`` or ``'OTHD'``.
    """
    label_names = ('HYD', 'MET', 'FD', 'EQ', 'OTHD')

    # Resize to 224x224, convert to a tensor, then normalise per channel.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    # Add a leading batch axis of size one.
    pixel_batch = preprocess(image).unsqueeze(0)

    # Captions are padded or truncated to exactly 128 tokens.
    tokens = tokenizer(
        caption,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )

    with torch.no_grad():
        scores = model(
            input_ids=tokens['input_ids'],
            attention_mask=tokens['attention_mask'],
            image=pixel_batch,
        )

    best_index = scores.argmax(dim=1).item()
    return label_names[best_index]
74
+
75
@st.cache_resource
def _load_tokenizer():
    """Load the caption tokenizer once and reuse it across Streamlit reruns."""
    return AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")


# === Streamlit UI ===
st.title("🌪️ Bangla Disaster Classifier")

uploaded_file = st.file_uploader("Upload an image", type=['jpg', 'png', 'jpeg'])
caption = st.text_area("Enter Bangla caption", "")

# Predict only when an image is present and the caption has visible text;
# a whitespace-only caption carries nothing to classify.
if uploaded_file and caption.strip():
    try:
        img = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises OSError (including UnidentifiedImageError) for files it
        # cannot decode; report it instead of crashing the app.
        st.error("Could not read the uploaded file as an image.")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width;
        # kept as-is to stay compatible with the installed version.
        st.image(img, caption="Uploaded Image", use_column_width=True)

        with st.spinner("Predicting..."):
            tokenizer = _load_tokenizer()
            model = load_model()
            prediction = predict(model, tokenizer, img, caption)

        st.success(f"✅ Predicted Disaster Class: **{prediction}**")