Fred808 commited on
Commit
ac185c8
·
verified ·
1 Parent(s): 137816f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ from typing import Dict
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoProcessor
9
+ from fastapi import FastAPI, File, UploadFile
10
+ from fastapi.responses import JSONResponse
11
+ import uvicorn
12
+
13
+ # Disable SDPA if not supported
14
+
15
+ # ==== CONFIGURATION ====
16
+ # Florence-2 Configuration
17
+ MODEL_ID = "microsoft/Florence-2-large"
18
+ DEVICE = "cpu" # Using CPU instead of GPU
19
+
20
+ # Create FastAPI app
21
+ app = FastAPI(title="Florence-2 Image Captioning API")
22
+
23
+ # Florence-2 Model (will be loaded once)
24
+ model = None
25
+ processor = None
26
+
27
+ def log_message(message: str):
28
+ """Simple logging function"""
29
+ timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
30
+ print(f"[{timestamp}] {message}")
31
+
32
+ def load_florence_model():
33
+ """Load Florence-2 model and processor"""
34
+ global model, processor
35
+ if model is None or processor is None:
36
+ try:
37
+ log_message("[*] Loading Florence-2 model and processor...")
38
+
39
+ # Load model on CPU
40
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).to(DEVICE)
41
+ model.eval()
42
+
43
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
44
+ log_message("[ ] Florence-2 loaded and ready on CPU")
45
+ except Exception as e:
46
+ log_message(f"[ERROR] Failed to load Florence-2 model: {e}")
47
+ raise
48
+
49
+ def caption_image(image: Image.Image) -> str:
50
+ """Generate detailed caption for an image using Florence-2"""
51
+ if model is None or processor is None:
52
+ return "Model not loaded."
53
+
54
+ task_prompt = "<MORE_DETAILED_CAPTION>"
55
+ prompt = task_prompt
56
+
57
+ try:
58
+ # Process image
59
+ inputs = processor(
60
+ text=prompt,
61
+ images=image,
62
+ return_tensors="pt",
63
+ padding=True,
64
+ truncation=True
65
+ ).to(DEVICE)
66
+
67
+ with torch.no_grad():
68
+ generated_ids = model.generate(
69
+ input_ids=inputs["input_ids"],
70
+ pixel_values=inputs["pixel_values"],
71
+ max_new_tokens=1350,
72
+ do_sample=True,
73
+ temperature=0.7,
74
+ top_p=0.9,
75
+ num_beams=3,
76
+ )
77
+
78
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
79
+ return generated_text
80
+
81
+ except Exception as e:
82
+ log_message(f"[!] Caption generation failed: {e}")
83
+ return "Captioning error."
84
+
85
+ @app.on_event("startup")
86
+ async def startup_event():
87
+ """Load model on startup"""
88
+ load_florence_model()
89
+
90
+ @app.post("/caption")
91
+ async def create_caption(file: UploadFile = File(...)) -> Dict:
92
+ """
93
+ API endpoint to receive an image and return its caption
94
+ """
95
+ try:
96
+ log_message(f"[API] Received image: {file.filename}")
97
+
98
+ # Read and validate image
99
+ contents = await file.read()
100
+ image = Image.open(BytesIO(contents)).convert("RGB")
101
+
102
+ # Generate caption
103
+ log_message(f"[API] Generating caption for {file.filename}")
104
+ caption = caption_image(image)
105
+
106
+ log_message(f"[API] Caption generated for {file.filename}: {caption[:100]}...")
107
+
108
+ return {
109
+ "status": "success",
110
+ "filename": file.filename,
111
+ "caption": caption
112
+ }
113
+
114
+ except Exception as e:
115
+ error_msg = f"Error processing image: {str(e)}"
116
+ log_message(f"[ERROR] {error_msg}")
117
+ return JSONResponse(
118
+ status_code=500,
119
+ content={
120
+ "status": "error",
121
+ "message": error_msg
122
+ }
123
+ )
124
+
125
+ if __name__ == "__main__":
126
+ log_message("Starting Florence-2 Vision Analysis API Server")
127
+ uvicorn.run(app, host="0.0.0.0", port=8000)