File size: 3,282 Bytes
8143e62
 
 
 
 
 
 
0896149
 
 
8143e62
 
 
0896149
 
 
 
 
8143e62
0896149
 
 
 
8143e62
0896149
 
 
 
8143e62
0896149
 
8143e62
 
0896149
 
 
 
 
8143e62
 
 
0896149
 
 
 
8143e62
0896149
 
8143e62
 
0896149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8143e62
 
0896149
 
 
 
8143e62
 
0896149
 
 
8143e62
0896149
 
 
 
 
 
 
 
 
8143e62
 
0896149
8143e62
 
 
 
0896149
 
 
 
 
 
 
 
 
 
8143e62
 
 
0896149
 
 
8143e62
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from PIL import Image
import torch
import io

# =========================
# Flask App Setup
# =========================
app = Flask(__name__)
# NOTE(review): CORS is enabled without an explicit origin list; with
# flask-cors defaults this permits every origin — tighten if exposed publicly.
CORS(app)

# =========================
# Device Setup
# =========================
# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# =========================
# Load Model & Tokenizer
# =========================
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

print("Loading model...")
# NOTE(review): ViTFeatureExtractor is deprecated in newer transformers
# releases in favour of ViTImageProcessor — consider migrating.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Module.to() and Module.eval() both return the module itself, so the
# device move and switch to inference mode can be chained onto the load.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device).eval()
print("Model loaded successfully")

# =========================
# Caption Generation Logic
# =========================
def generate_caption(
    image: Image.Image,
    *,
    max_length: int = 32,
    num_alternatives: int = 3,
    min_words: int = 3,
):
    """Generate a best caption plus alternative captions for ``image``.

    A single beam search produces both the best caption (its top hypothesis)
    and the alternatives (its top ``num_alternatives`` hypotheses). If the
    best caption has fewer than ``min_words`` words, one nucleus-sampling
    attempt is made, and its result replaces the beam caption only when it
    contains more words.

    Args:
        image: PIL image to caption. It is converted to RGB if needed and
            resized to 224x224 before feature extraction.
        max_length: Maximum sequence length passed to ``model.generate``.
        num_alternatives: How many beam-search captions to return as
            alternatives (must be >= 1).
        min_words: Word-count threshold below which the sampling fallback
            is attempted.

    Returns:
        A ``(best_caption, captions)`` tuple. ``best_caption`` is a stripped
        string; ``captions`` is a list of ``num_alternatives`` stripped
        strings in the order returned by beam search.

    Raises:
        ValueError: If ``num_alternatives`` is less than 1.
    """
    if num_alternatives < 1:
        raise ValueError("num_alternatives must be at least 1")

    # The model expects 3-channel input; palette/RGBA/grayscale uploads are
    # converted here.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NOTE(review): the feature extractor may already resize to the model's
    # input size; this explicit resize is kept so pixel inputs are unchanged.
    image = image.resize((224, 224))

    pixel_values = feature_extractor(
        images=image,
        return_tensors="pt",
    ).pixel_values.to(device)

    # One beam search serves both outputs. Asking for extra return sequences
    # does not change the search itself, so its top hypothesis is the same
    # caption a num_return_sequences=1 run would give — this avoids running
    # the (expensive) encoder + beam search twice per request.
    # num_return_sequences may not exceed num_beams, hence the max().
    beam_ids = model.generate(
        pixel_values,
        max_length=max_length,
        num_beams=max(5, num_alternatives),
        num_return_sequences=num_alternatives,
        repetition_penalty=1.2,
        early_stopping=True,
    )
    captions = [
        text.strip()
        for text in tokenizer.batch_decode(beam_ids, skip_special_tokens=True)
    ]
    best_caption = captions[0]

    # -------- Fallback: Nucleus Sampling --------
    if len(best_caption.split()) < min_words:
        sampled_ids = model.generate(
            pixel_values,
            max_length=max_length,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            repetition_penalty=1.2,
        )
        sampled_caption = tokenizer.decode(
            sampled_ids[0],
            skip_special_tokens=True,
        ).strip()
        # Sampling is random and can return an even shorter (or empty)
        # caption; only accept it when it is actually longer.
        if len(sampled_caption.split()) > len(best_caption.split()):
            best_caption = sampled_caption

    return best_caption, captions

# =========================
# API Endpoints
# =========================
@app.route("/", methods=["GET"])
def health_check():
    """Return a fixed JSON status payload with HTTP 200 to show the API is up."""
    payload = {"status": "Image Caption API is running"}
    return jsonify(payload), 200

@app.route("/caption", methods=["POST"])
def caption_image():
    """Caption an image uploaded as the multipart form field ``image``.

    Returns:
        200 with ``best_caption`` and ``alternative_captions`` on success.
        400 with ``error`` when no file is sent or it cannot be decoded as an
        image. 500 with ``error`` when caption generation itself fails.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image provided"}), 400

    image_file = request.files["image"]

    # Decode the upload first so bad input is reported as a client error.
    try:
        image = Image.open(io.BytesIO(image_file.read()))
        # Image.open only reads the header; force a full decode so truncated
        # or corrupt data fails here rather than deep inside generation.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # PIL's UnidentifiedImageError subclasses OSError, so "not an image"
        # and "truncated image" both land here.
        return jsonify({"error": "Invalid image file"}), 400
    except Exception as e:
        app.logger.exception("Failed to read uploaded image")
        return jsonify({"error": str(e)}), 500

    try:
        best_caption, all_captions = generate_caption(image)
    except Exception as e:
        # Request boundary: keep the JSON error contract, but record the
        # traceback, which was previously discarded.
        app.logger.exception("Caption generation failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "best_caption": best_caption,
        "alternative_captions": all_captions,
    })

# =========================
# Run Server
# =========================
if __name__ == "__main__":
    # Start Flask's built-in server listening on all network interfaces,
    # port 7860.
    app.run(port=7860, host="0.0.0.0")