Spaces:
Sleeping
Sleeping
File size: 3,282 Bytes
8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 0896149 8143e62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from PIL import Image
import torch
import io
# =========================
# Flask App Setup
# =========================
# The WSGI application object; routes below are registered on it.
app = Flask(import_name=__name__)
# Enable flask-cors on the whole app using its default settings.
CORS(app=app)
# =========================
# Device Setup
# =========================
# Run inference on CUDA when PyTorch can see a GPU, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)
# =========================
# Load Model & Tokenizer
# =========================
# Hub checkpoint supplying the encoder-decoder weights, the image
# preprocessing config, and the text tokenizer.
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
print("Loading model...")
# Preprocessor and tokenizer for turning images into pixel tensors and
# generated token ids back into text.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Load the weights, move them to the selected device, and switch to
# inference mode (nn.Module.to / .eval both return the module itself).
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device).eval()
print("Model loaded successfully")
# =========================
# Caption Generation Logic
# =========================
def generate_caption(image: Image.Image):
    """Generate captions for *image* with the loaded ViT-GPT2 model.

    A beam-search caption is produced first. If it has fewer than three
    words, nucleus sampling is tried, and the sampled caption is kept only
    if it contains more words than the beam-search one. Independently,
    three beam-search candidates are generated as alternatives.

    Args:
        image: A PIL image in any mode; it is converted to RGB if needed.

    Returns:
        A ``(best_caption, captions)`` tuple where ``best_caption`` is a
        stripped string and ``captions`` is a list of stripped strings, one
        per returned sequence (three, via ``num_return_sequences=3``).
    """
    # Ensure RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize for ViT
    image = image.resize((224, 224))

    # Extract features and move them to the model's device.
    pixel_values = feature_extractor(
        images=image,
        return_tensors="pt",
    ).pixel_values.to(device)

    # -------- First attempt: Beam Search --------
    output_ids = model.generate(
        pixel_values,
        max_length=32,
        num_beams=5,
        repetition_penalty=1.2,
        early_stopping=True,
    )
    best_caption = tokenizer.decode(
        output_ids[0],
        skip_special_tokens=True,
    ).strip()

    # -------- Fallback: Nucleus Sampling --------
    if len(best_caption.split()) < 3:
        output_ids = model.generate(
            pixel_values,
            max_length=32,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            repetition_penalty=1.2,
        )
        sampled_caption = tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
        ).strip()
        # Only adopt the sample if it is actually longer; otherwise a
        # degenerate (e.g. empty) sample would overwrite the beam result.
        if len(sampled_caption.split()) > len(best_caption.split()):
            best_caption = sampled_caption

    # -------- Multiple captions for robustness --------
    output_ids = model.generate(
        pixel_values,
        max_length=32,
        num_beams=5,
        num_return_sequences=3,
        repetition_penalty=1.2,
    )
    # Strip each candidate so alternatives are formatted like best_caption.
    captions = [
        caption.strip()
        for caption in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    ]
    return best_caption, captions
# =========================
# API Endpoints
# =========================
@app.route("/", methods=["GET"])
def health_check():
    """Report that the service is up; always answers 200 with a status message."""
    payload = {"status": "Image Caption API is running"}
    status_code = 200
    return jsonify(payload), status_code
@app.route("/caption", methods=["POST"])
def caption_image():
    """Caption an uploaded image.

    Expects a multipart upload with the file under the ``image`` form field.

    Returns:
        200 with ``best_caption`` and ``alternative_captions`` on success;
        400 if no file was sent or it cannot be decoded as an image;
        500 (details logged server-side) if anything else fails.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image provided"}), 400
    try:
        image = _read_uploaded_image(request.files["image"])
        if image is None:
            # Bad input is the client's fault, not a server error.
            return jsonify({"error": "Invalid image file"}), 400
        best_caption, all_captions = generate_caption(image)
    except Exception:
        # Top-level request boundary: keep the traceback in the server log
        # instead of echoing internal exception text to the client.
        app.logger.exception("Caption request failed")
        return jsonify({"error": "Internal server error"}), 500
    return jsonify({
        "best_caption": best_caption,
        "alternative_captions": all_captions,
    })


def _read_uploaded_image(file_storage):
    """Decode an uploaded file into a fully loaded PIL image, or return None if it is not a valid image."""
    try:
        image = Image.open(io.BytesIO(file_storage.read()))
        # Image.open is lazy; load() forces a full decode so truncated or
        # corrupt data is detected here rather than during captioning.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (unrecognised format) is an OSError subclass.
        app.logger.warning("Rejected undecodable image upload", exc_info=True)
        return None
    return image
# =========================
# Run Server
# =========================
if __name__ == "__main__":
    # Bind to all IPv4 interfaces rather than loopback only. Port 7860 is
    # presumably what the hosting platform expects — confirm before changing.
    app.run(port=7860, host="0.0.0.0")
|