Sher1988 committed on
Commit
eb55711
·
1 Parent(s): 282614a

Change structure of the project.

Browse files
app.py CHANGED
@@ -2,36 +2,33 @@ import torch
2
  import pandas as pd
3
  import streamlit as st
4
  from PIL import Image
5
- from models.encoder import EncoderCNN
6
- from models.decoder import DecoderRNN
 
7
  from utils.vocab import Vocabulary
8
- from torchvision import transforms as T
9
- from utils.helpers import VOCAB_FILE_PATH, CAPTIONS_TKN_PATH
10
- from inference import sample_with_temp, sample
11
  from utils.transforms import transforms
 
12
  import sacrebleu
13
- from huggingface_hub import hf_hub_download
14
  import os
15
 
16
  @st.cache_resource
17
  def load_models():
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
20
- repo_id = "Sher1988/image-classifier-weights"
21
- encoder_path = hf_hub_download(repo_id=repo_id, filename="encoder.pth")
22
- decoder_path = hf_hub_download(repo_id=repo_id, filename="decoder.pth")
23
 
24
  # Load captions and vocab
25
- captions = pd.read_csv(CAPTIONS_TKN_PATH).drop('tokens', axis=1)
26
- vocab = Vocabulary(load_path=VOCAB_FILE_PATH)
27
 
28
  # Initialize Models
29
  encoder = EncoderCNN(256).to(device)
30
  decoder = DecoderRNN(len(vocab), 256, 512).to(device)
31
 
32
  # Load Weights
33
- encoder.load_state_dict(torch.load(encoder_path, map_location=device))
34
- decoder.load_state_dict(torch.load(decoder_path, map_location=device))
35
 
36
  encoder.eval()
37
  decoder.eval()
@@ -41,7 +38,6 @@ def load_models():
41
 
42
  # --- Sidebar Configuration ---
43
  st.sidebar.header("Select an Example Image")
44
- IMAGE_DIR = "flickr8k/images" # Update this to your local images folder
45
 
46
  if os.path.exists(IMAGE_DIR):
47
  available_images = [f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
@@ -61,8 +57,8 @@ act_caps = []
61
  caption = ''
62
  st.title("📸 AI Image Captioner")
63
 
64
- temp = st.slider("Sampling Temperature", min_value=0.1, max_value=2.0, value=0.8, step=0.1)
65
- st.info("Higher temperature = more creative/random. Lower = more predictable.")
66
 
67
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
68
 
@@ -105,7 +101,7 @@ if img is not None:
105
 
106
  if act_caps:
107
  # sacrebleu expects a list of strings for hypothesis
108
- # and a list of lists of strings for references
109
  refs = [act_caps]
110
  sys = [caption]
111
 
 
2
  import pandas as pd
3
  import streamlit as st
4
  from PIL import Image
5
+
6
+ from encoder import EncoderCNN
7
+ from decoder import DecoderRNN
8
  from utils.vocab import Vocabulary
9
+ #from torchvision import transforms as T
10
+ from utils.helpers import VOCAB_PATH, CAPTIONS_PATH, ENCODER_PATH, DECODER_PATH, IMAGE_DIR
 
11
  from utils.transforms import transforms
12
+ from inference import sample_with_temp, sample
13
  import sacrebleu
 
14
  import os
15
 
16
  @st.cache_resource
17
  def load_models():
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
 
 
 
20
 
21
  # Load captions and vocab
22
+ captions = pd.read_csv(CAPTIONS_PATH)
23
+ vocab = Vocabulary(load_path=VOCAB_PATH)
24
 
25
  # Initialize Models
26
  encoder = EncoderCNN(256).to(device)
27
  decoder = DecoderRNN(len(vocab), 256, 512).to(device)
28
 
29
  # Load Weights
30
+ encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=device))
31
+ decoder.load_state_dict(torch.load(DECODER_PATH, map_location=device))
32
 
33
  encoder.eval()
34
  decoder.eval()
 
38
 
39
  # --- Sidebar Configuration ---
40
  st.sidebar.header("Select an Example Image")
 
41
 
42
  if os.path.exists(IMAGE_DIR):
43
  available_images = [f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
 
57
  caption = ''
58
  st.title("📸 AI Image Captioner")
59
 
60
+ temp = st.slider("Sampling Temperature", min_value=0.0, max_value=0.8, value=0.1, step=0.1)
61
+ st.info("Higher temperature = more creative/random. Lower temperature = more predictable.")
62
 
63
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
64
 
 
101
 
102
  if act_caps:
103
  # sacrebleu expects a list of strings for hypothesis
104
+ # and a list of lists of strings for references
105
  refs = [act_caps]
106
  sys = [caption]
107
 
{models → data}/captions_tokenized.csv RENAMED
The diff for this file is too large to render. See raw diff
 
{models → data}/vocabulary.json RENAMED
File without changes
models/decoder.py → decoder.py RENAMED
@@ -1,24 +1,24 @@
1
- import torch.nn as nn
2
-
3
-
4
- class DecoderRNN(nn.Module):
5
- def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, padding_idx=0):
6
- super(DecoderRNN, self).__init__()
7
- self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=padding_idx)
8
- self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
9
- self.linear = nn.Linear(hidden_size, vocab_size)
10
-
11
- self.init_h = nn.Linear(embed_size, hidden_size)
12
- self.init_c = nn.Linear(embed_size, hidden_size)
13
-
14
- def forward(self, features, captions, hidden=None):
15
- if hidden == None:
16
- h0 = self.init_h(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
17
- c0 = self.init_c(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
18
- hidden = (h0, c0)
19
- # dataflow: (B, seqlen) -> (B, hidden_size) -> (1, B, hidden_size) -> (num_layers, B, hidden_size)
20
-
21
- embeddings = self.embed(captions) # (B, seqlen) -> Training: (B, seqlen, embed_size) | Inference: (B, 1, embed_size)
22
- outputs, hidden = self.lstm(embeddings, hidden) # Training: (B, seqlen, hidden_size) | Inference: (B, 1, hidden_size)
23
- outputs = self.linear(outputs) # Training: (B, seqlen, vocab_size) | Inference: (B, 1, vocab_size)
24
  return outputs, hidden
 
1
import torch.nn as nn


class DecoderRNN(nn.Module):
    """LSTM caption decoder.

    Embeds token ids, runs them through an LSTM whose initial hidden/cell
    state is derived from the image features, and projects each step's
    output to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, padding_idx=0):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=padding_idx)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

        # Project image features into the LSTM's initial hidden/cell states.
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, features, captions, hidden=None):
        """Decode a batch of caption prefixes.

        Args:
            features: image embeddings, shape (B, embed_size); used only to
                seed the LSTM state when ``hidden`` is not supplied.
            captions: token ids, shape (B, seqlen) during training or (B, 1)
                during step-by-step inference.
            hidden: optional (h, c) state to continue decoding from.

        Returns:
            (outputs, hidden): logits of shape (B, seqlen, vocab_size) and
            the final LSTM state.
        """
        # Fix: identity check (`is None`) instead of `== None` — `==` on a
        # tensor-like object would attempt element-wise comparison.
        if hidden is None:
            # (B, embed_size) -> (B, hidden_size) -> (1, B, hidden_size)
            #                 -> (num_layers, B, hidden_size)
            h0 = self.init_h(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            c0 = self.init_c(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            hidden = (h0, c0)

        embeddings = self.embed(captions)                # (B, seqlen, embed_size)
        outputs, hidden = self.lstm(embeddings, hidden)  # (B, seqlen, hidden_size)
        outputs = self.linear(outputs)                   # (B, seqlen, vocab_size)
        return outputs, hidden
models/encoder.py → encoder.py RENAMED
@@ -1,24 +1,24 @@
1
- from torchvision.models import resnet50, ResNet50_Weights
2
- import torch.nn as nn
3
-
4
-
5
- class EncoderCNN(nn.Module):
6
- def __init__(self, embed_size, fine_tune=False):
7
- super(EncoderCNN, self).__init__()
8
- resnet = resnet50(weights=ResNet50_Weights.DEFAULT if fine_tune else None)
9
- for param in resnet.parameters():
10
- param.requires_grad = False
11
- if fine_tune:
12
- for param in resnet.layer4.parameters():
13
- param.requires_grad = True
14
- backbone = list(resnet.children())[:-1]
15
-
16
- self.resnet = nn.Sequential(*backbone)
17
- self.fc = nn.Linear(resnet.fc.in_features, embed_size)
18
- self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)
19
-
20
- def forward(self, images): # (B, C, W, H)
21
- features = self.resnet(images) # (B, 2048, 1, 1)
22
- features = features.reshape(features.shape[0], -1) # (B, 2048*1*1) not necessay to reshape as fc layer can take any size input
23
- return self.bn(self.fc(features)) # (B, embed_size)
24
-
 
1
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn


class EncoderCNN(nn.Module):
    """ResNet-50 image encoder producing a fixed-size embedding per image."""

    def __init__(self, embed_size, fine_tune=False):
        super(EncoderCNN, self).__init__()
        # Pretrained weights are only fetched when fine-tuning is requested;
        # otherwise the backbone starts uninitialized (weights loaded later).
        backbone_weights = ResNet50_Weights.DEFAULT if fine_tune else None
        resnet = resnet50(weights=backbone_weights)

        # Freeze the entire backbone, then optionally unfreeze the last stage.
        for param in resnet.parameters():
            param.requires_grad = False
        if fine_tune:
            for param in resnet.layer4.parameters():
                param.requires_grad = True

        # Keep everything up to (and including) the global average pool;
        # drop the final classification head.
        layers = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*layers)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images (B, C, H, W) into (B, embed_size)."""
        features = self.resnet(images)                       # (B, 2048, 1, 1)
        features = features.reshape(features.shape[0], -1)   # (B, 2048)
        return self.bn(self.fc(features))                    # (B, embed_size)
utils/helpers.py CHANGED
@@ -1,15 +1,16 @@
1
  # from enum import Enum
2
  import os
 
3
 
 
 
 
4
 
5
- DATA_DIR = 'data'
6
- CAPTIONS_FILE_PATH = os.path.join(DATA_DIR, 'flickr_data/captions.txt')
7
- IMAGES_PATH = os.path.join(DATA_DIR, 'flickr_data/Images')
8
 
 
9
 
10
- MODELS_PATH = 'models'
11
- ENCODER_MODEL_PATH = os.path.join(MODELS_PATH, 'encoder.pth')
12
- DECODER_MODEL_PATH = os.path.join(MODELS_PATH, 'decoder.pth')
13
- VOCAB_FILE_PATH = os.path.join(MODELS_PATH, 'vocabulary.json')
14
- TOKANIZED_CAPTIONS = os.path.join(MODELS_PATH, 'captions_tokenized.csv')
15
- CAPTIONS_TKN_PATH = os.path.join(MODELS_PATH, 'captions_tokenized.csv')
 
1
# from enum import Enum
import os
from huggingface_hub import hf_hub_download

# Model weights live on the Hugging Face Hub. hf_hub_download fetches them
# on first use and returns the local cache path on subsequent runs.
repo_id = "Sher1988/image-classifier-weights"
encoder_path = hf_hub_download(repo_id=repo_id, filename="encoder.pth")
decoder_path = hf_hub_download(repo_id=repo_id, filename="decoder.pth")


IMAGE_DIR = 'flickr8k/images'

DATA_DIR = 'data'
# Bug fix: the downloaded checkpoint paths were previously discarded while
# ENCODER_PATH/DECODER_PATH pointed at data/encoder.pth and data/decoder.pth,
# files this project does not ship. Point the constants at the cached
# downloads so torch.load(ENCODER_PATH) actually finds the weights.
ENCODER_PATH = encoder_path
DECODER_PATH = decoder_path
VOCAB_PATH = os.path.join(DATA_DIR, 'vocabulary.json')
CAPTIONS_PATH = os.path.join(DATA_DIR, 'captions_tokenized.csv')