# app.py — Flask image-captioning API (HF Space by sagar118, commit 0896149)
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from PIL import Image
import torch
import io
# =========================
# Flask App Setup
# =========================
app = Flask(__name__)
# Enable CORS so browser front-ends served from other origins can call this API.
CORS(app)
# =========================
# Device Setup
# =========================
# Prefer GPU when available; pixel tensors and the model are moved to this device below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# =========================
# Load Model & Tokenizer
# =========================
# Pretrained ViT encoder + GPT-2 decoder checkpoint, fetched from the Hugging Face hub
# at startup (requires network access on first run; cached afterwards).
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
print("Loading model...")
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.to(device)
# Inference only: disable dropout and other training-time behavior.
model.eval()
print("Model loaded successfully")
# =========================
# Caption Generation Logic
# =========================
def generate_caption(image: Image.Image):
    """Generate captions for *image* with the ViT-GPT2 model.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image in any mode; it is converted to RGB and resized to 224x224.

    Returns
    -------
    tuple[str, list[str]]
        ``(best_caption, captions)`` — ``captions`` are the top-3 beam-search
        hypotheses (score-sorted by ``generate``), and ``best_caption`` is the
        highest-scoring one, or a nucleus-sampling retry when beam search
        produces a degenerate caption of fewer than 3 words.
    """
    # The model was trained on RGB inputs; normalize grayscale/RGBA/etc. first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # ViT expects 224x224 input. The feature extractor would also resize, but
    # resizing here keeps the original pipeline's behavior explicit.
    image = image.resize((224, 224))
    pixel_values = feature_extractor(
        images=image,
        return_tensors="pt"
    ).pixel_values.to(device)

    # Single beam search returning the 3 best hypotheses. The original code ran
    # an identical 5-beam search twice (once for the best caption, once for the
    # alternatives); since beam outputs are sorted by score, one call suffices
    # and halves inference cost. no_grad avoids autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=32,
            num_beams=5,
            num_return_sequences=3,
            repetition_penalty=1.2,
            early_stopping=True,
        )
    # Strip all hypotheses so alternatives match best_caption's formatting.
    captions = [
        text.strip()
        for text in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    ]
    best_caption = captions[0]

    # Fallback: if beam search collapsed to a very short caption, retry with
    # nucleus sampling for a more diverse (non-deterministic) output.
    if len(best_caption.split()) < 3:
        with torch.no_grad():
            output_ids = model.generate(
                pixel_values,
                max_length=32,
                do_sample=True,
                top_p=0.9,
                temperature=0.8,
                repetition_penalty=1.2,
            )
        best_caption = tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True
        ).strip()

    return best_caption, captions
# =========================
# API Endpoints
# =========================
@app.route("/", methods=["GET"])
def health_check():
    """Liveness probe: report that the API process is up and serving."""
    payload = {"status": "Image Caption API is running"}
    return jsonify(payload), 200
@app.route("/caption", methods=["POST"])
def caption_image():
    """Caption a multipart-uploaded image.

    Expects the file under the ``image`` form field; responds with the best
    caption plus alternative beam hypotheses, or a JSON error payload.
    """
    # Guard clause: reject requests that carry no file upload.
    if "image" not in request.files:
        return jsonify({"error": "No image provided"}), 400
    try:
        uploaded = request.files["image"]
        raw_bytes = uploaded.read()
        pil_image = Image.open(io.BytesIO(raw_bytes))
        best_caption, all_captions = generate_caption(pil_image)
        response = {
            "best_caption": best_caption,
            "alternative_captions": all_captions,
        }
        return jsonify(response)
    except Exception as e:
        # API boundary: surface any failure (bad image data, model error)
        # as a 500 with the exception message instead of crashing the worker.
        return jsonify({"error": str(e)}), 500
# =========================
# Run Server
# =========================
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    serve_host = "0.0.0.0"
    serve_port = 7860
    app.run(host=serve_host, port=serve_port)