hbatali2020 committed on
Commit
79b3ead
·
verified ·
1 Parent(s): e7c3f4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -59
app.py CHANGED
@@ -13,15 +13,19 @@ sys.modules["flash_attn.bert_padding"] = types.ModuleType("flash_attn.bert_paddi
13
 
14
  import io
15
  import time
 
16
  import torch
17
  from PIL import Image
18
  from transformers import AutoProcessor, AutoModelForCausalLM
19
  from fastapi import FastAPI, HTTPException, UploadFile, File
 
 
 
20
  from contextlib import asynccontextmanager
 
21
 
22
  MODEL_ID = "microsoft/Florence-2-large-ft"
23
 
24
- # ─── السؤال الأصلي + تأكيد على اليد ─────────────────────────────
25
  VQA_QUESTION = (
26
  "Is there a woman or any part of a woman's body in this image? Answer yes or no only."
27
  )
@@ -46,82 +50,107 @@ async def lifespan(app: FastAPI):
46
  MODEL_DATA.clear()
47
 
48
  app = FastAPI(
49
- title="Female Detection API - VQA",
50
- description="Florence-2-large-ft | VQA",
51
- version="4.3.0",
52
  lifespan=lifespan
53
  )
54
 
55
- @app.get("/health")
56
- def health():
57
- return {"status": "ok", "model_loaded": "model" in MODEL_DATA}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- def decide(answer: str) -> tuple[str, str]:
60
- """
61
- "no" → allow ✅
62
- "yes" → block 🔴
63
- أي شيء آخر → block 🔴 للأمان
64
- """
65
- a = answer.strip().lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if a == "no" or a.startswith("no"):
67
- return "allow", "model_answered_no"
68
  elif "yes" in a:
69
- return "block", "model_answered_yes"
70
  else:
71
- return "block", "unexpected_answer_blocked_for_safety"
72
 
73
- @app.post("/analyze")
74
- async def analyze_image(file: UploadFile = File(...)):
 
 
 
 
 
75
 
76
- if not file.content_type.startswith("image/"):
77
- raise HTTPException(status_code=400, detail="الملف ليس صورة")
 
 
78
 
 
 
 
 
79
  try:
80
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
 
 
 
81
  except Exception as e:
82
- raise HTTPException(status_code=400, detail=f"خطأ في قراءة الصورة: {str(e)}")
83
 
84
  try:
85
- processor = MODEL_DATA["processor"]
86
- model = MODEL_DATA["model"]
87
-
88
- task = "<VQA>"
89
- prompt = f"{task}{VQA_QUESTION}"
90
-
91
- inputs = processor(text=prompt, images=image, return_tensors="pt")
92
-
93
- start_time = time.time()
94
- with torch.no_grad():
95
- generated_ids = model.generate(
96
- input_ids=inputs["input_ids"],
97
- pixel_values=inputs["pixel_values"],
98
- max_new_tokens=10,
99
- num_beams=3,
100
- do_sample=False
101
- )
102
-
103
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
104
- parsed = processor.post_process_generation(
105
- generated_text,
106
- task=task,
107
- image_size=(image.width, image.height)
108
- )
109
- elapsed = round(time.time() - start_time, 2)
110
 
111
- answer = parsed.get(task, "").strip()
112
- decision, reason = decide(answer)
113
 
114
- return {
115
- "decision": decision,
116
- "reason": reason,
117
- "vqa_answer": answer,
118
- "question": VQA_QUESTION,
119
- "execution_time": elapsed,
120
- "status": "success"
121
- }
122
 
 
 
123
  except Exception as e:
124
- raise HTTPException(status_code=500, detail=str(e))
 
 
125
 
126
 
127
  if __name__ == "__main__":
 
13
 
14
  import io
15
  import time
16
+ import httpx
17
  import torch
18
  from PIL import Image
19
  from transformers import AutoProcessor, AutoModelForCausalLM
20
  from fastapi import FastAPI, HTTPException, UploadFile, File
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from fastapi.responses import JSONResponse
23
+ from pydantic import BaseModel
24
  from contextlib import asynccontextmanager
25
+ from typing import Optional
26
 
27
  MODEL_ID = "microsoft/Florence-2-large-ft"
28
 
 
29
  VQA_QUESTION = (
30
  "Is there a woman or any part of a woman's body in this image? Answer yes or no only."
31
  )
 
50
  MODEL_DATA.clear()
51
 
52
  app = FastAPI(
53
+ title="AI Shield - Female Detection API",
54
+ description="Florence-2-large-ft | VQA | Compatible with AI Shield Chrome Extension",
55
+ version="5.0.0",
56
  lifespan=lifespan
57
  )
58
 
59
+ # ─── CORS: ضروري لإضافة Chrome ───────────────────────────────────
60
+ app.add_middleware(
61
+ CORSMiddleware,
62
+ allow_origins=["*"],
63
+ allow_credentials=True,
64
+ allow_methods=["*"],
65
+ allow_headers=["*"],
66
+ )
67
+
68
+ # ─── Schemas ─────────────────────────────────────────────────────
69
+ class ImageUrlRequest(BaseModel):
70
+ image_url: str # من إضافة Chrome
71
+
72
+ # ─── دالة التحليل المشتركة ───────────────────────────────────────
73
+ def analyze_image(image: Image.Image) -> dict:
74
+ processor = MODEL_DATA["processor"]
75
+ model = MODEL_DATA["model"]
76
+
77
+ task = "<VQA>"
78
+ prompt = f"{task}{VQA_QUESTION}"
79
+
80
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
81
 
82
+ start_time = time.time()
83
+ with torch.no_grad():
84
+ generated_ids = model.generate(
85
+ input_ids=inputs["input_ids"],
86
+ pixel_values=inputs["pixel_values"],
87
+ max_new_tokens=10,
88
+ num_beams=3,
89
+ do_sample=False
90
+ )
91
+
92
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
93
+ parsed = processor.post_process_generation(
94
+ generated_text,
95
+ task=task,
96
+ image_size=(image.width, image.height)
97
+ )
98
+ elapsed = round(time.time() - start_time, 2)
99
+ answer = parsed.get(task, "").strip()
100
+
101
+ # ─── منطق القرار ─────────────────────────────────────────────────
102
+ a = answer.lower()
103
  if a == "no" or a.startswith("no"):
104
+ decision, reason = "ALLOW", "model_answered_no"
105
  elif "yes" in a:
106
+ decision, reason = "BLOCK", "model_answered_yes"
107
  else:
108
+ decision, reason = "BLOCK", "unexpected_answer_blocked_for_safety"
109
 
110
+ return {
111
+ "decision": decision, # ALLOW | BLOCK (بالحروف الكبيرة لتتوافق مع الإضافة)
112
+ "reason": reason,
113
+ "vqa_answer": answer,
114
+ "execution_time": elapsed,
115
+ "status": "success"
116
+ }
117
 
118
+ # ─── Health Check ────────────────────────────────────────────────
119
+ @app.get("/health")
120
+ def health():
121
+ return {"status": "ok", "model_loaded": "model" in MODEL_DATA}
122
 
123
+ # ─── Endpoint 1: من إضافة Chrome (image_url) ─────────────────────
124
+ # background.js يرسل: POST /analyze {"image_url": "https://..."}
125
+ @app.post("/analyze")
126
+ async def analyze_from_url(request: ImageUrlRequest):
127
  try:
128
+ async with httpx.AsyncClient(timeout=30) as client:
129
+ response = await client.get(request.image_url)
130
+ response.raise_for_status()
131
+ image_bytes = response.content
132
  except Exception as e:
133
+ raise HTTPException(status_code=400, detail=f"ูุดู„ ุชุญู…ูŠู„ ุงู„ุตูˆุฑุฉ ู…ู† URL: {str(e)}")
134
 
135
  try:
136
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
137
+ except Exception as e:
138
+ raise HTTPException(status_code=400, detail=f"ุฎุทุฃ ููŠ ู‚ุฑุงุกุฉ ุงู„ุตูˆุฑุฉ: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ return analyze_image(image)
 
141
 
142
+ # ─── Endpoint 2: اختبار يدوي (رفع ملف) ───────────────────────────
143
+ @app.post("/analyze-file")
144
+ async def analyze_from_file(file: UploadFile = File(...)):
145
+ if not file.content_type.startswith("image/"):
146
+ raise HTTPException(status_code=400, detail="الملف ليس صورة")
 
 
 
147
 
148
+ try:
149
+ image = Image.open(io.BytesIO(await file.read())).convert("RGB")
150
  except Exception as e:
151
+ raise HTTPException(status_code=400, detail=f"خطأ في قراءة الصورة: {str(e)}")
152
+
153
+ return analyze_image(image)
154
 
155
 
156
  if __name__ == "__main__":