File size: 4,452 Bytes
7774aa1
 
 
 
299e19c
7774aa1
17342c1
7774aa1
a921c54
886527c
a921c54
886527c
38a66e9
886527c
cb453e9
7774aa1
46ba0f6
a921c54
 
5262916
e55e1b4
cb453e9
886527c
 
 
 
299e19c
886527c
 
 
5262916
 
 
a921c54
44d28de
e9452ae
52da0ad
886527c
 
f968395
52da0ad
 
e55e1b4
17342c1
 
 
 
 
 
 
 
79b3ead
 
 
e55e1b4
886527c
 
 
 
 
 
 
 
 
79b3ead
54c5c2b
e55e1b4
79b3ead
 
44d28de
79b3ead
f968395
a921c54
79b3ead
 
54c5c2b
5262916
 
 
886527c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f968395
886527c
 
5262916
 
cb453e9
5262916
 
886527c
 
 
1f8fa81
886527c
 
 
 
 
 
1d77b61
886527c
 
 
 
 
5262916
886527c
 
54c5c2b
 
5262916
886527c
 
 
5262916
 
54c5c2b
 
 
 
44d28de
a921c54
094b62f
7774aa1
094b62f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import io
import time
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager

# Hugging Face model identifier of the vision-language model used for VQA.
MODEL_ID = "HuggingFaceTB/SmolVLM-500M-Instruct"

# Fixed yes/no question posed to the model for every uploaded image.
VQA_QUESTION = (
    "Is there a human being or any part of a human body in the picture? Answer yes or no"
)

# Filled by the lifespan handler with "processor" and "model"; cleared on shutdown.
MODEL_DATA = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: load the VLM once at startup, release it at shutdown.

    Populates the module-level MODEL_DATA dict with "processor" and "model";
    the yield hands control to the running application.
    """
    print(f"๐Ÿ“ฅ Loading {MODEL_ID}...")
    start = time.time()

    MODEL_DATA["processor"] = AutoProcessor.from_pretrained(
        MODEL_ID,
        size={"longest_edge": 1 * 512}  # lower the input resolution to speed up CPU inference
    )
    MODEL_DATA["model"] = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,  # full precision for the CPU deployment
        _attn_implementation="eager"
    ).eval()

    print(f"โœ… Model ready in {time.time()-start:.1f}s")
    yield
    # Shutdown: drop the references so the model can be garbage-collected.
    MODEL_DATA.clear()

# Application instance; the lifespan handler above loads/unloads the model.
app = FastAPI(
    title="Female Detection API - SmolVLM-500M",
    description="HuggingFaceTB/SmolVLM-500M-Instruct | VQA",
    version="1.0.0",
    lifespan=lifespan
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec; Starlette compensates by echoing the request
# origin, which effectively lets ANY site make credentialed requests.
# Tighten the origin list before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/health")
def health():
    """Liveness probe: reports whether the model finished loading at startup."""
    model_loaded = "model" in MODEL_DATA
    return {"status": "ok", "model_loaded": model_loaded}

def decide(answer: str) -> tuple[str, str]:
    """Map the model's free-text VQA answer to a moderation decision.

    Returns a (decision, reason) pair where decision is "allow" or "block".
    Fail-safe policy: anything that is not a clear "no" is blocked.
    """
    # Inspect only the first word, stripped of trailing punctuation.
    # The previous check used a.startswith("no"), which also matched
    # "not sure", "none", "nothing", ... and wrongly ALLOWED them.
    words = answer.strip().lower().split()
    first = words[0].strip(".,!?:;") if words else ""
    if first == "no":
        return "allow", "model_answered_no"
    if "yes" in answer.lower():
        return "block", "model_answered_yes"
    return "block", "unexpected_answer_blocked_for_safety"

@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Run the yes/no human-presence VQA on an uploaded image.

    Returns a JSON payload with the moderation decision, the raw model
    answer, the question asked, and the generation time in seconds.
    Raises HTTP 400 for non-image uploads or unreadable image data,
    HTTP 500 on inference failure.
    """
    # content_type is None when the client omits the header; the old check
    # crashed with AttributeError there (surfacing as a 500 instead of a 400).
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="ุงู„ู…ู„ู ู„ูŠุณ ุตูˆุฑุฉ")

    try:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"ุฎุทุฃ ููŠ ู‚ุฑุงุกุฉ ุงู„ุตูˆุฑุฉ: {str(e)}")

    try:
        processor = MODEL_DATA["processor"]
        model     = MODEL_DATA["model"]

        # Official SmolVLM chat format: the image slot is declared here,
        # the actual pixels are handed to the processor further below.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": VQA_QUESTION}
                ]
            }
        ]

        prompt = processor.apply_chat_template(
            messages,
            add_generation_prompt=True
        )

        # The processor pairs the prompt's image placeholder with the image.
        inputs = processor(
            text=prompt,
            images=[image],
            return_tensors="pt"
        )

        start_time = time.time()
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=20,  # a yes/no answer never needs more
                do_sample=False,    # greedy decoding -> deterministic answer
            )

        # Keep only the newly generated tokens (drop the echoed prompt).
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        answer = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0].strip()

        elapsed          = round(time.time() - start_time, 2)
        decision, reason = decide(answer)

        return {
            "decision":       decision,
            "reason":         reason,
            "vqa_answer":     answer,
            "question":       VQA_QUESTION,
            "execution_time": elapsed,
            "status":         "success"
        }

    except Exception as e:
        # Service boundary: surface any inference failure as a 500.
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    # Development entry point: serve directly with uvicorn on port 7860.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)