hbatali2020 committed on
Commit
f968395
·
verified ·
1 Parent(s): 5326748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -85
app.py CHANGED
@@ -2,16 +2,19 @@ import io
2
  import time
3
  import torch
4
  from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForCausalLM
6
  from fastapi import FastAPI, HTTPException, UploadFile, File
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from contextlib import asynccontextmanager
9
 
10
- MODEL_ID = "Qwen/Qwen3.5-0.8B"
11
 
12
- VQA_QUESTION = (
13
- "Is there a woman or any part of a woman's body in this image? Answer yes or no."
14
- )
 
 
 
15
 
16
  MODEL_DATA = {}
17
 
@@ -20,28 +23,19 @@ async def lifespan(app: FastAPI):
20
  print(f"๐Ÿ“ฅ Loading {MODEL_ID}...")
21
  start = time.time()
22
 
23
- MODEL_DATA["processor"] = AutoProcessor.from_pretrained(
24
- MODEL_ID,
25
- trust_remote_code=True
26
- )
27
- MODEL_DATA["model"] = AutoModelForCausalLM.from_pretrained(
28
- MODEL_ID,
29
- torch_dtype=torch.float32,
30
- trust_remote_code=True,
31
- attn_implementation="eager",
32
- device_map="cpu"
33
  ).eval()
34
 
35
- # โ”€โ”€โ”€ DEBUG: ู†ุฑู‰ ู…ุง ูŠู‚ุจู„ู‡ ุงู„ู†ู…ูˆุฐุฌ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
36
- sig = MODEL_DATA["model"].forward.__code__.co_varnames
37
  print(f"โœ… Model ready in {time.time()-start:.1f}s")
38
- print(f"๐Ÿ“‹ Model forward args: {list(sig)[:20]}")
39
  yield
40
  MODEL_DATA.clear()
41
 
42
  app = FastAPI(
43
- title="Female Detection API - Qwen3.5-0.8B",
44
- version="1.3.0",
 
45
  lifespan=lifespan
46
  )
47
 
@@ -57,15 +51,6 @@ app.add_middleware(
57
  def health():
58
  return {"status": "ok", "model_loaded": "model" in MODEL_DATA}
59
 
60
def decide(answer: str) -> tuple[str, str]:
    """Map the VQA model's free-text answer onto an (action, reason) pair.

    Policy is fail-closed: only a clear "no" allows the image; a "yes" or
    any unexpected answer blocks it.

    Note: `startswith("no")` also matches answers such as "none" or "not
    sure" and treats them as "no" — this matches the original behavior
    and is kept deliberately.
    """
    a = answer.strip().lower()
    # The original `a == "no" or a.startswith("no")` was redundant:
    # equality with "no" already implies startswith("no").
    if a.startswith("no"):
        return "allow", "model_answered_no"
    if "yes" in a:
        return "block", "model_answered_yes"
    return "block", "unexpected_answer_blocked_for_safety"
68
-
69
  @app.post("/analyze")
70
  async def analyze_image(file: UploadFile = File(...)):
71
 
@@ -73,8 +58,7 @@ async def analyze_image(file: UploadFile = File(...)):
73
  raise HTTPException(status_code=400, detail="ุงู„ู…ู„ู ู„ูŠุณ ุตูˆุฑุฉ")
74
 
75
  try:
76
- image_bytes = await file.read()
77
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
78
  except Exception as e:
79
  raise HTTPException(status_code=400, detail=f"ุฎุทุฃ ููŠ ู‚ุฑุงุกุฉ ุงู„ุตูˆุฑุฉ: {str(e)}")
80
 
@@ -82,70 +66,52 @@ async def analyze_image(file: UploadFile = File(...)):
82
  processor = MODEL_DATA["processor"]
83
  model = MODEL_DATA["model"]
84
 
85
- messages = [
86
- {
87
- "role": "user",
88
- "content": [
89
- {"type": "image", "image": image},
90
- {"type": "text", "text": VQA_QUESTION}
91
- ]
92
- }
93
- ]
94
-
95
- inputs = processor.apply_chat_template(
96
- messages,
97
- tokenize=True,
98
- add_generation_prompt=True,
99
- return_dict=True,
100
  return_tensors="pt"
101
  )
102
 
103
- # โ”€โ”€โ”€ ุงู„ุญู„: ู†ุญุฐู ุงู„ู€ keys ุงู„ุชูŠ ู„ุง ูŠู‚ุจู„ู‡ุง ุงู„ู†ู…ูˆุฐุฌ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
104
- # Qwen3.5 ูŠุณุชุฎุฏู… Early Fusion โ†’ ุงู„ุตูˆุฑุฉ ู…ุฏู…ุฌุฉ ููŠ input_ids
105
- KEYS_TO_REMOVE = [
106
- "mm_token_type_ids",
107
- "pixel_values",
108
- "image_grid_thw",
109
- "pixel_values_videos",
110
- "video_grid_thw",
111
- "second_per_grid_ts",
112
- ]
113
- clean_inputs = {
114
- k: v for k, v in inputs.items()
115
- if k not in KEYS_TO_REMOVE
116
- }
117
-
118
- # โ”€โ”€โ”€ DEBUG: ู†ุฑู‰ ู…ุง ุชุจู‚ู‰ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
119
- print(f"๐Ÿ”‘ Keys sent to generate: {list(clean_inputs.keys())}")
120
-
121
  start_time = time.time()
122
  with torch.no_grad():
123
- generated_ids = model.generate(
124
- **clean_inputs,
125
- max_new_tokens=20,
126
- do_sample=False,
127
- temperature=None,
128
- top_p=None,
129
- )
130
-
131
- generated_ids_trimmed = [
132
- out_ids[len(in_ids):]
133
- for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  ]
135
- answer = processor.batch_decode(
136
- generated_ids_trimmed,
137
- skip_special_tokens=True,
138
- clean_up_tokenization_spaces=False
139
- )[0].strip()
140
 
141
- elapsed = round(time.time() - start_time, 2)
142
- decision, reason = decide(answer)
 
 
143
 
144
  return {
145
  "decision": decision,
146
- "reason": reason,
147
- "vqa_answer": answer,
148
- "question": VQA_QUESTION,
149
  "execution_time": elapsed,
150
  "status": "success"
151
  }
 
2
  import time
3
  import torch
4
  from PIL import Image
5
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
6
  from fastapi import FastAPI, HTTPException, UploadFile, File
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from contextlib import asynccontextmanager
9
 
10
MODEL_ID = "IDEA-Research/grounding-dino-base"

# Grounding DINO prompt rule: the detection query must be lowercase and every
# phrase must be terminated with " ." (period separator) — this is required by
# the model, per its model card.
DETECTION_TEXT = "woman . girl . female . person . human . hand . arm . face . leg . finger ."

# Confidence threshold, used both for box filtering and text matching.
THRESHOLD = 0.3

# Populated by the lifespan handler at startup: holds the loaded
# "processor" and "model"; cleared again at shutdown.
MODEL_DATA = {}
20
 
 
23
# NOTE(review): the decorator and signature below are reconstructed from the
# diff context (the `asynccontextmanager` import and the old hunk header
# `async def lifespan(app: FastAPI):`) — confirm against the full file.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler.

    Loads the Grounding DINO processor and model once at startup into
    MODEL_DATA, yields while the app serves requests, then releases them
    at shutdown.
    """
    print(f"📥 Loading {MODEL_ID}...")
    start = time.time()

    MODEL_DATA["processor"] = AutoProcessor.from_pretrained(MODEL_ID)
    MODEL_DATA["model"] = AutoModelForZeroShotObjectDetection.from_pretrained(
        MODEL_ID
    ).eval()  # eval mode: inference only, no training-time behavior

    print(f"✅ Model ready in {time.time()-start:.1f}s")
    yield  # application runs here
    MODEL_DATA.clear()
34
 
35
# API metadata kept in one place so the service surface is easy to scan.
_APP_METADATA = {
    "title": "Female Detection API - Grounding DINO Base",
    "description": "IDEA-Research/grounding-dino-base | Zero-Shot Object Detection",
    "version": "1.0.0",
}

app = FastAPI(lifespan=lifespan, **_APP_METADATA)
41
 
 
51
def health():
    """Liveness/readiness probe: reports whether the model has finished loading."""
    model_is_loaded = "model" in MODEL_DATA
    return {"status": "ok", "model_loaded": model_is_loaded}
53
 
 
 
 
 
 
 
 
 
 
54
@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Run zero-shot object detection on an uploaded image and decide allow/block.

    The image is scanned for the phrases in DETECTION_TEXT (woman/girl/body
    parts, etc.); any detection above THRESHOLD blocks the image.

    Returns a dict with: decision ("block"/"allow"), summary, detected_count,
    detections (label/confidence/bbox per hit), execution_time, status.

    Raises HTTPException(400) for non-image uploads or unreadable image data.
    """
    # NOTE(review): this guard condition was elided in the diff this was
    # reconstructed from — only the raise line was visible. Confirm it matches
    # the original check on file.content_type.
    if not (file.content_type or "").startswith("image"):
        raise HTTPException(status_code=400, detail="الملف ليس صورة")

    try:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"خطأ في قراءة الصورة: {str(e)}")

    processor = MODEL_DATA["processor"]
    model = MODEL_DATA["model"]

    inputs = processor(
        images=image,
        text=DETECTION_TEXT,
        return_tensors="pt"
    )

    start_time = time.time()
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process raw logits into boxes/scores/labels in pixel coordinates.
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        threshold=THRESHOLD,
        text_threshold=THRESHOLD,
        target_sizes=[image.size[::-1]]  # PIL size is (w, h); model expects (h, w)
    )[0]

    elapsed = round(time.time() - start_time, 2)

    boxes = results["boxes"].tolist()
    scores = results["scores"].tolist()
    labels = results["labels"]

    # Defensive re-filter: post-processing already applies `threshold`, so
    # this should be a no-op, but it guards against library behavior changes.
    detections = [
        {
            "label": label,
            "confidence": round(score, 3),
            "bbox": [round(x, 1) for x in box]
        }
        for label, score, box in zip(labels, scores, boxes)
        if score >= THRESHOLD
    ]

    decision = "block" if detections else "allow"

    # Fix: the original used `set(...)` for label dedup, which makes the
    # summary string's label order non-deterministic between runs.
    # dict.fromkeys dedups while preserving first-seen order.
    unique_labels = list(dict.fromkeys(d["label"] for d in detections))
    summary = (
        f"yes detected: {', '.join(unique_labels)}"
        if detections
        else "no detected human body"
    )

    return {
        "decision": decision,
        "summary": summary,
        "detected_count": len(detections),
        "detections": detections,
        "execution_time": elapsed,
        "status": "success"
    }