# Hugging Face Space: ViT-GPT2 image-captioning API
# (Space status banner "Spaces: Sleeping" removed from the code listing)
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from PIL import Image
import torch
import io
# =========================
# Flask App Setup
# =========================
app = Flask(__name__)
CORS(app)  # allow cross-origin requests so browser frontends can call this API

# =========================
# Device Setup
# =========================
# Prefer GPU when available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# =========================
# Load Model & Tokenizer
# =========================
# Pretrained ViT-encoder + GPT-2-decoder image-captioning checkpoint.
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
print("Loading model...")
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates
print("Model loaded successfully")
# =========================
# Caption Generation Logic
# =========================
def generate_caption(image: Image.Image):
    """Generate captions for *image* with the ViT-GPT2 model.

    Args:
        image: a PIL image in any mode; converted to RGB internally.

    Returns:
        (best_caption, captions): ``best_caption`` is a single caption
        string; ``captions`` is a list of 3 beam-search alternatives
        (the first alternative is the beam-search best).
    """
    # The ViT encoder expects 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NOTE(review): the feature extractor already resizes/normalizes to the
    # model's 224x224 input; this explicit resize is kept for parity with
    # the original behavior.
    image = image.resize((224, 224))

    pixel_values = feature_extractor(
        images=image,
        return_tensors="pt",
    ).pixel_values.to(device)

    # Inference only: disable autograd so generate() does not build (and
    # retain) gradient graphs — the original version wasted memory here.
    with torch.no_grad():
        # -------- Beam search: best caption AND alternatives in ONE pass ----
        # The original ran an identical 5-beam search twice (once for the
        # best caption, once for alternatives); num_return_sequences=3 gives
        # both results from a single generation call.
        output_ids = model.generate(
            pixel_values,
            max_length=32,
            num_beams=5,
            num_return_sequences=3,
            repetition_penalty=1.2,
            early_stopping=True,
        )
        captions = [
            c.strip()
            for c in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]
        best_caption = captions[0]  # beam search returns sequences best-first

        # -------- Fallback: Nucleus Sampling --------
        # If beam search degenerated into a too-short caption, retry once
        # with stochastic sampling for a more varied output.
        if len(best_caption.split()) < 3:
            output_ids = model.generate(
                pixel_values,
                max_length=32,
                do_sample=True,
                top_p=0.9,
                temperature=0.8,
                repetition_penalty=1.2,
            )
            best_caption = tokenizer.decode(
                output_ids[0],
                skip_special_tokens=True,
            ).strip()

    return best_caption, captions
# =========================
# API Endpoints
# =========================
# NOTE(review): the original function had no @app.route decorator, so Flask
# never registered it and the endpoint was unreachable. Path "/" chosen as
# the conventional health-check root — confirm against the client/frontend.
@app.route("/", methods=["GET"])
def health_check():
    """Liveness probe: report that the API process is up and serving."""
    return jsonify({"status": "Image Caption API is running"}), 200
# NOTE(review): the original function had no @app.route decorator, so Flask
# never registered it and the endpoint was unreachable. Path "/caption" is a
# reasonable guess — confirm against the client/frontend.
@app.route("/caption", methods=["POST"])
def caption_image():
    """Accept a multipart-form image upload and return generated captions.

    Expects the file under form field ``image``. Responds with JSON:
    ``{"best_caption": str, "alternative_captions": [str, ...]}`` on
    success, ``{"error": ...}`` with status 400 (no file) or 500
    (decode/inference failure) otherwise.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image provided"}), 400
    try:
        image_file = request.files["image"]
        image = Image.open(io.BytesIO(image_file.read()))
        best_caption, all_captions = generate_caption(image)
        return jsonify({
            "best_caption": best_caption,
            "alternative_captions": all_captions,
        })
    except Exception as e:
        # Broad catch is deliberate at this API boundary: any decoding or
        # inference failure becomes a JSON 500 rather than an HTML error page.
        return jsonify({"error": str(e)}), 500
# =========================
# Run Server
# =========================
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    app.run(host="0.0.0.0", port=7860)