sagar118 committed on
Commit
8143e62
·
verified ·
1 Parent(s): 88079cc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
4
+ from PIL import Image
5
+ import torch
6
+ import io
7
+
8
# Flask application; CORS is enabled with flask_cors defaults.
app = Flask(__name__)
CORS(app)

print("Loading model...")

# A single checkpoint supplies the model, the image preprocessor and the
# tokenizer, so the identifier is spelled out only once.
_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

print("Model loaded successfully")
27
+
28
def predict_caption(
    image: Image.Image, *, max_length: int = 16, num_beams: int = 4
) -> str:
    """Generate a caption for *image* using the module-level model.

    Args:
        image: PIL image to describe. It is converted to RGB first when it
            is in any other mode.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
        num_beams: Value forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess a one-image batch and move it to the model's device.
    pixel_values = feature_extractor(
        images=[image], return_tensors="pt"
    ).pixel_values.to(device)

    output_ids = model.generate(
        pixel_values,
        max_length=max_length,
        num_beams=num_beams,
    )

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    # Only one image was submitted, so only the first prediction is relevant.
    return preds[0].strip()
46
+
47
@app.route("/caption", methods=["POST"])
def caption():
    """Caption the image uploaded in the ``image`` multipart form field.

    Returns:
        A JSON body ``{"caption": <text>}`` on success. On failure a JSON
        body ``{"error": <message>}`` with status 400 when the ``image``
        field is missing or cannot be opened as an image, or status 500
        when caption generation raises.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image provided"}), 400

    image_file = request.files["image"]
    try:
        image = Image.open(io.BytesIO(image_file.read()))
    except OSError:
        # Pillow signals unreadable image data with UnidentifiedImageError,
        # an OSError subclass; that is a client error, not a server fault.
        return jsonify({"error": "Invalid image file"}), 400

    try:
        text = predict_caption(image)
    except Exception as e:
        # Request boundary: report inference failures as JSON.
        return jsonify({"error": str(e)}), 500
    return jsonify({"caption": text})
60
+
61
@app.route("/")
def health():
    """Return a fixed plain-text message showing the service is up."""
    status_message = "Image Caption API is running"
    return status_message
64
+
65
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 7860, but only
    # when this file is executed directly rather than imported.
    app.run(port=7860, host="0.0.0.0")