from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from PIL import Image
import torch
import io

# =========================
# Flask App Setup
# =========================
app = Flask(__name__)
CORS(app)

# =========================
# Device Setup
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# =========================
# Load Model & Tokenizer
# =========================
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

print("Loading model...")
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()
print("Model loaded successfully")


# =========================
# Caption Generation Logic
# =========================
def generate_caption(image: Image.Image):
    """Generate captions for a PIL image.

    Strategy:
      1. Beam search (num_beams=5) for a single "best" caption.
      2. If that caption is suspiciously short (< 3 words), retry once
         with nucleus sampling (top_p=0.9, temperature=0.8) as a fallback.
      3. A second beam search with num_return_sequences=3 produces
         alternative captions for the API response.

    Args:
        image: Input PIL image; converted to RGB if needed.

    Returns:
        Tuple of (best_caption: str, captions: list[str]) where
        ``captions`` holds 3 alternative beam-search captions.
    """
    # Model expects 3-channel input; normalize palette/grayscale/alpha modes.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NOTE(review): the feature extractor resizes internally as well, so this
    # explicit resize is likely redundant — kept to preserve the original
    # preprocessing pipeline exactly.
    image = image.resize((224, 224))

    # Extract ViT pixel values and move them to the model's device.
    pixel_values = feature_extractor(
        images=image,
        return_tensors="pt",
    ).pixel_values.to(device)

    # Inference only: disable autograd so no computation graph is built
    # (saves memory per request; the original omitted this).
    with torch.no_grad():
        # -------- First attempt: Beam Search --------
        output_ids = model.generate(
            pixel_values,
            max_length=32,
            num_beams=5,
            repetition_penalty=1.2,
            early_stopping=True,
        )
        best_caption = tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
        ).strip()

        # -------- Fallback: Nucleus Sampling --------
        # A very short beam-search result usually means a degenerate caption;
        # retry once with sampling for more diversity.
        if len(best_caption.split()) < 3:
            output_ids = model.generate(
                pixel_values,
                max_length=32,
                do_sample=True,
                top_p=0.9,
                temperature=0.8,
                repetition_penalty=1.2,
            )
            best_caption = tokenizer.decode(
                output_ids[0],
                skip_special_tokens=True,
            ).strip()

        # -------- Multiple captions for robustness --------
        output_ids = model.generate(
            pixel_values,
            max_length=32,
            num_beams=5,
            num_return_sequences=3,
            repetition_penalty=1.2,
        )
        # Strip each alternative for consistency with best_caption
        # (the original returned them unstripped).
        captions = [
            caption.strip()
            for caption in tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
            )
        ]

    return best_caption, captions


# =========================
# API Endpoints
# =========================
@app.route("/", methods=["GET"])
def health_check():
    """Liveness probe: confirms the service is up and the model loaded."""
    return jsonify({"status": "Image Caption API is running"}), 200


@app.route("/caption", methods=["POST"])
def caption_image():
    """Accept a multipart upload under the 'image' field and caption it.

    Returns:
        200 with {"best_caption": str, "alternative_captions": [str, ...]}
        400 if no 'image' file field is present.
        500 if the image cannot be decoded or generation fails.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image provided"}), 400

    try:
        image_file = request.files["image"]
        image = Image.open(io.BytesIO(image_file.read()))

        best_caption, all_captions = generate_caption(image)

        return jsonify({
            "best_caption": best_caption,
            "alternative_captions": all_captions,
        })
    except Exception as e:
        # Broad catch is intentional at this API boundary: any decode or
        # generation failure becomes a JSON 500 rather than an HTML stack trace.
        return jsonify({"error": str(e)}), 500


# =========================
# Run Server
# =========================
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)