hbatali2020 commited on
Commit
a921c54
·
verified ·
1 Parent(s): aa59406

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -69
app.py CHANGED
@@ -1,95 +1,158 @@
1
- import os
2
- import torch
 
 
 
 
 
 
 
 
 
 
 
3
  import io
 
 
 
4
  from PIL import Image
5
- from fastapi import FastAPI, File, UploadFile
 
6
  from fastapi.responses import JSONResponse
7
- from fastapi.middleware.cors import CORSMiddleware
8
- from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # ุฅุนุฏุงุฏ ุงู„ู…ุณุงุฑุงุช ู„ุถู…ุงู† ุงู„ุตู„ุงุญูŠุงุช ุฏุงุฎู„ Hugging Face
11
- os.environ["HF_HOME"] = "/app/cache"
 
 
12
 
13
- app = FastAPI(title="Florence-2 Image Captioning API")
 
 
 
 
 
 
 
 
 
 
14
 
15
- # ุชูุนูŠู„ CORS
16
- app.add_middleware(
17
- CORSMiddleware,
18
- allow_origins=["*"],
19
- allow_methods=["*"],
20
- allow_headers=["*"],
 
 
 
 
 
 
 
 
21
  )
22
 
23
- model_id = "microsoft/Florence-2-large"
24
- device = "cpu"
 
 
 
25
 
26
- print(f"โณ Loading {model_id} (CPU Optimized)...")
 
 
 
27
 
28
- # 1. ุชุญู…ูŠู„ ุงู„ุฅุนุฏุงุฏุงุช ูˆุฅุตู„ุงุญ ู…ุดูƒู„ุฉ forced_bos_token_id
29
- config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
30
- if not hasattr(config, 'forced_bos_token_id'):
31
- config.forced_bos_token_id = None
 
 
 
 
 
32
 
33
- # 2. ุชุญู…ูŠู„ ุงู„ู…ุนุงู„ุฌ
34
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
35
 
36
- # 3. ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ ู…ุน ุชุนุทูŠู„ Flash Attention ู„ุงุณุชุฎุฏุงู…ู‡ ุนู„ู‰ ุงู„ู€ CPU
37
- model = AutoModelForCausalLM.from_pretrained(
38
- model_id,
39
- config=config,
40
- trust_remote_code=True,
41
- torch_dtype=torch.float32,
42
- attn_implementation="sdpa" # ๐Ÿ‘ˆ ู‡ุฐุง ุงู„ุณุทุฑ ูŠุญู„ ู…ุดูƒู„ุฉ ImportError: flash_attn
43
- ).to(device).eval()
44
 
45
- @app.post("/describe")
46
- async def describe_image(file: UploadFile = File(...)):
47
  try:
48
- # ู‚ุฑุงุกุฉ ุงู„ุตูˆุฑุฉ
49
- contents = await file.read()
50
- image = Image.open(io.BytesIO(contents)).convert("RGB")
51
-
52
- # ู…ู‡ู…ุฉ ุงู„ูˆุตู ุงู„ุชูุตูŠู„ูŠ
53
- task_prompt = "<DETAILED_CAPTION>"
54
-
55
- # ุชุฌู‡ูŠุฒ ุงู„ู…ุฏุฎู„ุงุช
56
- inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(device)
57
-
58
- # ุชูˆู„ูŠุฏ ุงู„ู†ุต
 
 
 
 
 
 
 
 
 
 
 
59
  with torch.no_grad():
60
  generated_ids = model.generate(
61
  input_ids=inputs["input_ids"],
62
  pixel_values=inputs["pixel_values"],
63
- max_new_tokens=1024,
64
- num_beams=3,
65
- do_sample=False
66
- )
67
-
68
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
69
-
70
- # ู…ุนุงู„ุฌุฉ ุงู„ู†ุชุงุฆุฌ
71
- parsed_answer = processor.post_process_generation(
72
- generated_text,
73
- task=task_prompt,
74
- image_size=(image.width, image.height)
75
  )
76
 
77
- return {
78
- "description": parsed_answer[task_prompt],
79
- "status": "success"
80
- }
 
 
 
81
 
82
- except Exception as e:
83
- print(f"โŒ Error: {str(e)}")
84
- return JSONResponse(
85
- status_code=500,
86
- content={"error": "Processing failed", "details": str(e)}
87
  )
88
 
89
- @app.get("/")
90
- def health_check():
91
- return {"status": "online", "device": device, "endpoint": "/describe"}
92
 
93
  if __name__ == "__main__":
94
- import uvicorn
95
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
import sys
import types
import importlib.util

# ─── Mock flash_attn ─────────────────────────────────────────────────
# Florence-2's remote code imports flash_attn unconditionally; register a
# stub module so the import succeeds on CPU-only hosts where the real
# package cannot be installed.
flash_mock = types.ModuleType("flash_attn")
flash_mock.__version__ = "2.0.0"
flash_mock.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)
sys.modules["flash_attn"] = flash_mock
# Submodules some versions of the remote code import directly.
for _sub_name in ("flash_attn.flash_attn_interface", "flash_attn.bert_padding"):
    sys.modules[_sub_name] = types.ModuleType(_sub_name)
# ─────────────────────────────────────────────────────────────────────
import io
import base64
import time
from contextlib import asynccontextmanager
from typing import Optional

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn
# ─── Supported task prompts ──────────────────────────────────────────
# The Florence-2 task tokens this API accepts on /predict.
AVAILABLE_TASKS = [
    "<CAPTION>",
    "<DETAILED_CAPTION>",
    "<MORE_DETAILED_CAPTION>",
    "<OD>",
    "<OCR>",
    "<OCR_WITH_REGION>",
]

# ─── Model registry ──────────────────────────────────────────────────
# Filled in by the lifespan handler once the model has loaded; holds
# "processor", "model" and "device" keys.
MODEL = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load Florence-2 at startup, publish it via MODEL, release on shutdown."""
    print("๐Ÿ“ฅ ุฌุงุฑูŠ ุชุญู…ูŠู„ Florence-2-large...")
    start = time.time()

    repo_id = "microsoft/Florence-2-large"
    loaded_processor = AutoProcessor.from_pretrained(
        repo_id,
        trust_remote_code=True,
    )
    loaded_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float32,       # full precision; CPU has no fp16 benefit
        trust_remote_code=True,
        attn_implementation="eager",     # avoid flash-attn / sdpa kernels on CPU
    )
    loaded_model.eval()

    # Publish for the request handlers.
    MODEL.update(processor=loaded_processor, model=loaded_model, device="cpu")

    print(f"โœ… ุงู„ู†ู…ูˆุฐุฌ ุฌุงู‡ุฒ ููŠ {time.time()-start:.1f}s")
    yield
    # Shutdown: drop references so the weights can be garbage-collected.
    MODEL.clear()
# ─── FastAPI application ─────────────────────────────────────────────
# The lifespan handler above loads/unloads the model with the app.
app = FastAPI(
    title="Florence-2-large API",
    description="Microsoft Florence-2-large - Image Analysis API",
    version="1.0.0",
    lifespan=lifespan,
)
71
 
# ─── Schemas ─────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Body of POST /predict."""
    image: str               # base64-encoded image bytes
    task: str = "<CAPTION>"  # must be one of AVAILABLE_TASKS
    # Optional text appended to the task prompt (grounding-style tasks).
    # Was annotated `str = None`, which pydantic v2 rejects when a client
    # sends an explicit null — Optional makes None a valid value.
    text_input: Optional[str] = None


class PredictResponse(BaseModel):
    """Body returned by POST /predict."""
    result: dict             # parsed Florence-2 output keyed by task token
    task: str                # the task prompt that was executed
    elapsed_seconds: float   # inference wall time, rounded to 2 decimals
# ─── Endpoints ───────────────────────────────────────────────────────
@app.get("/")
def root():
    """Service banner: status, model id, device and supported task prompts."""
    current_device = MODEL.get("device", "loading...")
    return {
        "status": "running",
        "model": "microsoft/Florence-2-large",
        "device": current_device,
        "available_tasks": AVAILABLE_TASKS,
    }
92
 
@app.get("/health")
def health():
    """Liveness probe; model_loaded flips to True once lifespan finishes."""
    is_loaded = "model" in MODEL
    return {"status": "ok", "model_loaded": is_loaded}
96
 
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Run one Florence-2 task on a base64-encoded image.

    Returns the parsed model output plus the inference wall time.
    Raises HTTPException 400 for an unknown task or an undecodable image,
    500 when inference itself fails.
    """
    # Reject task prompts the model is not advertised to support.
    if request.task not in AVAILABLE_TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"ู…ู‡ู…ุฉ ุบูŠุฑ ุตุงู„ุญุฉ. ุงู„ู…ู‡ุงู… ุงู„ู…ุชุงุญุฉ: {AVAILABLE_TASKS}"
        )

    # Decode the image; chain the cause so logs keep the real decode error.
    try:
        image_bytes = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"ุฎุทุฃ ููŠ ุงู„ุตูˆุฑุฉ: {str(e)}") from e

    # Build the prompt: the task token, optionally extended with
    # caller-supplied text (used by grounding-style tasks).
    prompt = request.task
    if request.text_input:
        prompt += request.text_input

    try:
        processor = MODEL["processor"]
        model = MODEL["model"]

        inputs = processor(
            text=prompt,
            images=image,
            return_tensors="pt"
        )

        start = time.time()
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
                do_sample=False,   # deterministic output
                num_beams=1
            )

        # Keep special tokens: post_process_generation parses them to build
        # the task-specific structured result.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed = processor.post_process_generation(
            generated_text,
            task=request.task,
            image_size=(image.width, image.height)
        )
        elapsed = time.time() - start

        return PredictResponse(
            result=parsed,
            task=request.task,
            elapsed_seconds=round(elapsed, 2)
        )

    except Exception as e:
        # Surface inference failures as 500; `from e` preserves the original
        # traceback in server logs instead of discarding it.
        raise HTTPException(status_code=500, detail=f"ุฎุทุฃ ููŠ ุงู„ุงุณุชุฏู„ุงู„: {str(e)}") from e
155
+
156
 
if __name__ == "__main__":
    # 7860 is the port Hugging Face Spaces exposes by default.
    uvicorn.run(app, host="0.0.0.0", port=7860)