Jaiking001 committed
Commit d1e6a4c · verified · 1 Parent(s): 696cd98

first commit

Files changed (4):
  1. app.py +124 -0
  2. dataset.py +182 -0
  3. predict.py +158 -0
  4. train.py +123 -0
app.py ADDED
@@ -0,0 +1,124 @@
import gradio as gr
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
from PIL import Image


class Vocabulary:
    def __init__(self):
        self.itos = {}
        self.stoi = {}

    def load(self, stoi, itos):
        self.stoi = stoi
        self.itos = itos


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]  # drop the final FC layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.linear(features)
        features = self.bn(features)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs

    def sample(self, features, vocab, max_len=30):
        """Greedy decoding: start from the image feature, feed back the previous word until <end>."""
        output_ids = []
        states = None

        inputs = features.unsqueeze(1)

        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            predicted = outputs.argmax(1)
            output_ids.append(predicted.item())

            if vocab.itos[predicted.item()] == "<end>":
                break

            inputs = self.embed(predicted).unsqueeze(1)

        return output_ids


# === Load checkpoint, vocabulary and models ===
checkpoint_path = "./checkpoints/caption_model_epoch30.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embed_size = 256
hidden_size = 512
num_layers = 1

checkpoint = torch.load(checkpoint_path, map_location=device)

vocab = Vocabulary()
vocab.load(checkpoint['vocab_stoi'], checkpoint['vocab_itos'])
vocab_size = len(vocab.stoi)

encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

encoder.eval()
decoder.eval()

# Same preprocessing as during training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


def generate_caption(image):
    image = Image.fromarray(image).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = encoder(image)
        output_ids = decoder.sample(features, vocab)

    caption = []
    for idx in output_ids:
        word = vocab.itos[idx]
        if word == "<end>":
            break
        caption.append(word)

    return ' '.join(caption)


demo = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Skin Disease Image Captioning",
    description="Upload an image of a skin disease to generate a descriptive caption using the trained model."
)

if __name__ == "__main__":
    demo.launch()
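
For a quick local sanity check without launching the Gradio UI, the handler can be called directly, since gr.Image(type="numpy") passes the upload to generate_caption as a NumPy array. A minimal sketch, assuming a hypothetical test image path (not part of this commit):

import numpy as np
test_img = np.array(Image.open("path/to/some_test_image.jpg").convert("RGB"))
print(generate_caption(test_img))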
dataset.py ADDED
@@ -0,0 +1,182 @@
import os
import json
from PIL import Image
from collections import Counter

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms
import spacy

# ===== Load spaCy English tokenizer =====
spacy_eng = spacy.load("en_core_web_sm")


class Vocabulary:
    def __init__(self, freq_threshold):
        """
        freq_threshold: minimum word frequency to keep in vocab
        """
        self.freq_threshold = freq_threshold

        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """
        Uses the spaCy tokenizer to split a sentence into a list of tokens
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """
        Builds vocab: {word -> index} for all words with freq >= threshold
        """
        frequencies = Counter()
        idx = 4  # Start indexing after special tokens

        for sentence in sentence_list:
            tokens = self.tokenizer_eng(sentence)
            frequencies.update(tokens)

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = idx
                self.itos[idx] = word
                idx += 1

    def numericalize(self, text):
        """
        Converts a text caption to a list of vocab indices
        """
        tokenized_text = self.tokenizer_eng(text)
        return [
            self.stoi.get(token, self.stoi["<unk>"])
            for token in tokenized_text
        ]


class CaptionDataset(Dataset):
    def __init__(self, images_dir, captions_file, vocab, transform=None):
        """
        images_dir: path to images/train or images/val
        captions_file: JSON file with "images" and "annotations"
        vocab: Vocabulary object
        transform: torchvision transform
        """
        self.images_dir = images_dir
        self.vocab = vocab
        self.transform = transform

        # Load JSON
        with open(captions_file, 'r') as f:
            data = json.load(f)

        self.images = data["images"]
        self.annotations = data["annotations"]

        # Create map: image_id -> file_name
        self.id_to_filename = {img["id"]: img["file_name"] for img in self.images}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        ann = self.annotations[index]
        image_id = ann["image_id"]
        caption = ann["caption"]

        # Build image path
        img_path = os.path.join(self.images_dir, self.id_to_filename[image_id])

        # Open image
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Numericalize caption + add <start> and <end> tokens
        numericalized_caption = [self.vocab.stoi["<start>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<end>"])

        return image, torch.tensor(numericalized_caption)


def build_vocab_from_json(captions_file, freq_threshold):
    """
    Builds a Vocabulary object from the JSON annotations file.
    """
    with open(captions_file, 'r') as f:
        data = json.load(f)

    all_captions = [ann["caption"] for ann in data["annotations"]]

    vocab = Vocabulary(freq_threshold)
    vocab.build_vocabulary(all_captions)

    return vocab


def my_collate_fn(batch):
    """
    Custom collate_fn for variable-length captions:
    pads captions in the batch to the max length in the batch.
    """
    images = []
    captions = []

    for img, cap in batch:
        images.append(img)
        captions.append(cap)

    images = torch.stack(images, dim=0)
    captions = pad_sequence(captions, batch_first=True, padding_value=0)  # pad with <pad> token idx 0

    return images, captions


# ====== Test block ======
if __name__ == "__main__":
    # === Paths ===
    captions_train_json = "./Dataset/annotations/captions_train.json"
    images_train_dir = "./Dataset/images/train/"

    # === Build vocab ===
    vocab = build_vocab_from_json(captions_train_json, freq_threshold=2)
    print(f"Vocab size: {len(vocab)}")

    # === Transforms ===
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # === Create dataset ===
    train_dataset = CaptionDataset(
        images_dir=images_train_dir,
        captions_file=captions_train_json,
        vocab=vocab,
        transform=transform
    )

    # === DataLoader with custom collate_fn ===
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=my_collate_fn  # ✅ REQUIRED for variable-length captions
    )

    # === Test loop ===
    for idx, (images, captions) in enumerate(train_loader):
        print(f"\nBatch {idx + 1}")
        print("Images shape:", images.shape)      # [B, 3, H, W]
        print("Captions shape:", captions.shape)  # [B, T] (padded)
        print("Sample caption:", captions[0])
        break  # one batch test only
predict.py ADDED
@@ -0,0 +1,158 @@
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
from PIL import Image

import os

# ===========
# Vocabulary
# ===========

class Vocabulary:
    def __init__(self):
        self.itos = {}
        self.stoi = {}

    def load(self, stoi, itos):
        self.stoi = stoi
        self.itos = itos

# ===========
# Encoder
# ===========

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)               # [B, 2048, 1, 1]
        features = features.view(features.size(0), -1)   # [B, 2048]
        features = self.linear(features)                 # [B, embed_size]
        features = self.bn(features)                     # [B, embed_size]
        return features

# ===========
# Decoder
# ===========

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs

    def sample(self, features, vocab, max_len=30):
        """
        Generates a caption for given image features using greedy search.
        """
        output_ids = []
        states = None

        inputs = features.unsqueeze(1)  # [B, 1, embed_size]

        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)   # [B, 1, hidden]
            outputs = self.linear(hiddens.squeeze(1))     # [B, vocab_size]
            predicted = outputs.argmax(1)                 # [B]
            output_ids.append(predicted.item())

            if vocab.itos[predicted.item()] == "<end>":
                break

            inputs = self.embed(predicted).unsqueeze(1)

        return output_ids

# ===========
# Predict block
# ===========

def predict(image_path, checkpoint_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embed_size = 256
    hidden_size = 512
    num_layers = 1

    # === Load checkpoint ===
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # === Load vocab ===
    vocab = Vocabulary()
    vocab.load(checkpoint['vocab_stoi'], checkpoint['vocab_itos'])
    vocab_size = len(vocab.stoi)

    # === Load models ===
    encoder = EncoderCNN(embed_size).to(device)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    encoder.eval()
    decoder.eval()

    # === Image transform ===
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    with torch.no_grad():
        # === Encode ===
        features = encoder(image)

        # === Decode ===
        output_ids = decoder.sample(features, vocab)

    # === Convert IDs to words ===
    caption = []
    for idx in output_ids:
        word = vocab.itos[idx]
        if word == "<end>":
            break
        caption.append(word)

    final_caption = ' '.join(caption)
    print(f"\n📝 Predicted caption: {final_caption}\n")
    return final_caption


if __name__ == "__main__":
    # ✅ Change these!
    image_path = r"C:\Users\Jayasimma D\Documents\Skin_Disease_Captioning\Dataset\images\train\Albinism\Albinism2.jpg"  # 🔍 your test image path
    checkpoint_path = "./checkpoints/caption_model_epoch5.pth"

    predict(image_path, checkpoint_path)
train.py ADDED
@@ -0,0 +1,123 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as transforms

from dataset import build_vocab_from_json, CaptionDataset, my_collate_fn


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]  # remove FC layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        # Keep the pretrained ResNet frozen; only linear + bn are trained
        with torch.no_grad():
            features = self.resnet(images)               # [B, 2048, 1, 1]
        features = features.view(features.size(0), -1)   # [B, 2048]
        features = self.linear(features)
        features = self.bn(features)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])                    # Exclude <end>
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)   # Add image feature at t=0
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs


# === Hyperparameters ===
embed_size = 256
hidden_size = 512
num_layers = 1
learning_rate = 3e-4
num_epochs = 30
batch_size = 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Data ===
captions_train_json = "./Dataset/annotations/captions_train.json"
images_train_dir = "./Dataset/images/train/"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

vocab = build_vocab_from_json(captions_train_json, freq_threshold=2)
vocab_size = len(vocab)

train_dataset = CaptionDataset(
    images_dir=images_train_dir,
    captions_file=captions_train_json,
    vocab=vocab,
    transform=transform
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate_fn
)

# === Models, loss, optimizer ===
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore <pad> positions
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = optim.Adam(params, lr=learning_rate)

encoder.train()
decoder.train()

os.makedirs("checkpoints", exist_ok=True)

# === Training loop ===
for epoch in range(num_epochs):
    for idx, (imgs, captions) in enumerate(train_loader):
        imgs, captions = imgs.to(device), captions.to(device)

        features = encoder(imgs)
        outputs = decoder(features, captions)

        outputs = outputs[:, 1:, :]  # drop the image-feature step -> [B, T-1, vocab_size]

        outputs = outputs.reshape(-1, vocab_size)
        targets = captions[:, 1:].reshape(-1)

        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{idx}/{len(train_loader)}] Loss: {loss.item():.4f}")

    torch.save({
        'epoch': epoch + 1,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'vocab_stoi': vocab.stoi,
        'vocab_itos': vocab.itos,
    }, f"checkpoints/caption_model_epoch{epoch+1}.pth")

    print(f"✅ Saved model to checkpoints/caption_model_epoch{epoch+1}.pth")

print("Training complete ✅")