File size: 5,713 Bytes
4a6156f
d24ae2d
 
 
 
 
4a6156f
22e64d4
34aa0cd
c84e518
 
 
 
22e64d4
e4ed76e
 
 
b037393
22e64d4
 
b037393
 
 
 
 
 
 
 
 
 
c84e518
 
b037393
c84e518
4a6156f
22e64d4
e4ed76e
afcf209
e4ed76e
22e64d4
afcf209
 
 
 
 
 
 
 
 
 
 
 
22e64d4
c84e518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22e64d4
e4ed76e
4a6156f
e4ed76e
4a6156f
22e64d4
82e76f6
 
afcf209
82e76f6
 
0687ea7
f39cbba
e4ed76e
f40cf99
6dbca3d
0687ea7
f40cf99
 
 
c84e518
6dbca3d
955743d
 
c84e518
955743d
 
 
 
b037393
955743d
b037393
955743d
 
 
 
 
b037393
955743d
b037393
 
 
 
 
955743d
22e64d4
f40cf99
 
c84e518
 
f40cf99
 
 
 
afcf209
f40cf99
 
 
 
 
 
 
f39cbba
22e64d4
f40cf99
 
f39cbba
 
 
d24ae2d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# app.py — TrashTrack Turbo (compatible with ESP32 + multipart/form-data)
import os
# Avoid the libgomp error and excessive CPU thread spawning; set before the
# heavyweight torch/transformers imports below so they pick these values up.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, ImageOps
import io, time, torch
import numpy as np

from transformers import AutoProcessor, AutoModel

# ==============================================
# ⚙️ Configuration
# ==============================================
MODEL_ID = "google/siglip-so400m-patch14-384"  # highest-accuracy model choice
device = "cuda" if torch.cuda.is_available() else "cpu"

# Automatically detect whether torchvision is installed; the "fast" image
# processor path in transformers requires it.
try:
    import torchvision  # noqa: F401
    USE_FAST = True
except Exception:
    USE_FAST = False

print(f"🚀 Carregando modelo {MODEL_ID} (use_fast={USE_FAST})...")
# Falls back to the "slow" processor path when torchvision is missing
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=USE_FAST)
# NOTE(review): newer transformers accept dtype=; older releases expect
# torch_dtype= — verify against the pinned transformers version.
model = AutoModel.from_pretrained(
    MODEL_ID,
    dtype=torch.float16 if device == "cuda" else None
).to(device).eval()
print("✅ Modelo carregado com sucesso.")

# ==============================================
# 📋 Classes (PT + EN) — **WITHOUT GLASS**
# ==============================================
# Maps each API class name to the prompt terms scored for it. A class's
# confidence is the mean probability over its terms, and predict() relies
# on the insertion order of this dict to slice the flat probability list.
labels = {
    "plastico": [
        "plástico", "garrafa PET", "tampinha plástica",
        "sacola plástica", "plastic bottle"
    ],
    "papel": [
        "papel", "folha", "envelope de papel",
        "paper sheet", "paper wrapper"
    ],
    "metal": [
        "lata", "alumínio", "tampinha metálica",
        "metal cap", "can"
    ],
}

def _promptize(term: str) -> str:
    return f"centered {term} on a white background; ignore the background; classify only the object"

texts = [_promptize(t) for group in labels.values() for t in group]

# ==============================================
# 🔧 Util — crop the foreground, ignoring white background
# ==============================================
def crop_foreground_ignore_white(pil: Image.Image) -> Image.Image:
    """Crop *pil* to the bounding box of its non-white pixels.

    Pixels with all RGB channels above 230 are treated as background.
    A 3% margin is added around the detected box. If fewer than 500
    foreground pixels exist, falls back to a centered 80% crop.
    """
    rgb = pil.convert("RGB")
    pixels = np.asarray(rgb)
    # Foreground mask: any channel at or below the white threshold.
    foreground = ~np.all(pixels > 230, axis=-1)

    if foreground.sum() < 500:
        # Almost everything looks white — use a centered 80% crop instead.
        w, h = rgb.size
        cw, ch = int(w * 0.8), int(h * 0.8)
        x0, y0 = (w - cw) // 2, (h - ch) // 2
        return rgb.crop((x0, y0, x0 + cw, y0 + ch))

    rows, cols = np.where(foreground)
    top, bottom = rows.min(), rows.max()
    left, right = cols.min(), cols.max()

    # Pad the box by 3% of each dimension, clamped to the image bounds.
    pad_y, pad_x = int(0.03 * rgb.height), int(0.03 * rgb.width)
    top = max(0, top - pad_y)
    bottom = min(rgb.height - 1, bottom + pad_y)
    left = max(0, left - pad_x)
    right = min(rgb.width - 1, right + pad_x)

    return rgb.crop((left, top, right + 1, bottom + 1))

# ==============================================
# 🌐 FastAPI app
# ==============================================
app = FastAPI(title="TrashTrack Turbo — ESP32 Compatible")

@app.get("/")
def root():
    """Health-check endpoint: reports the loaded model and known classes."""
    payload = {
        "ok": True,
        "model": MODEL_ID,
        "mode": "multipart/files[]",
        "classes": list(labels.keys()),
    }
    return payload

@app.post("/predict")
async def predict(files_: list[UploadFile] = File(..., alias="files[]")):
    """Classify uploaded images into one of the classes in `labels`.

    Accepts one or more images as multipart/form-data under the field
    name "files[]" (the format the ESP32 firmware sends). Each image is
    EXIF-rotated, cropped to its foreground, embedded with the SigLIP
    model and scored against the text prompts; per-image winners are
    combined by averaging the confidences of images that voted for the
    same class.

    Returns JSON {"label", "conf", "latency_s"} on success,
    {"error": ...} with HTTP 400 for an empty upload or 500 on failure.
    """
    try:
        t0 = time.time()

        # Guard: an empty upload would otherwise crash max() below and
        # surface as an opaque 500.
        if not files_:
            return JSONResponse({"error": "no files received"}, status_code=400)

        # The text prompts are constant for the whole request, so embed
        # them once up front instead of once per image (the original
        # recomputed them inside the per-file loop).
        shared_txt_emb = None
        if hasattr(model, "get_text_features"):
            text_inputs = processor(text=texts, return_tensors="pt", padding=True).to(device)
            with torch.inference_mode():
                shared_txt_emb = model.get_text_features(**text_inputs)

        results = []  # (best_label, confidence) per image

        for f in files_:
            data = await f.read()
            image = Image.open(io.BytesIO(data))
            # Apply EXIF orientation before any pixel-based processing.
            image = ImageOps.exif_transpose(image).convert("RGB")
            image = crop_foreground_ignore_white(image)

            with torch.inference_mode():
                txt_emb = shared_txt_emb
                img_emb = None
                if hasattr(model, "get_image_features"):
                    image_inputs = processor(images=image, return_tensors="pt").to(device)
                    img_emb = model.get_image_features(**image_inputs)

                # Fallback for models without the separate feature APIs:
                # one joint forward pass yields both embedding sets.
                if txt_emb is None or img_emb is None:
                    joint = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
                    out = model(**joint)
                    if txt_emb is None:
                        txt_emb = getattr(out, "text_embeds", getattr(out, "text_embeds_projected", None))
                    if img_emb is None:
                        img_emb = getattr(out, "image_embeds", getattr(out, "image_embeds_projected", None))

                # Cosine similarity between the image and every prompt,
                # softmaxed into a probability over all prompts.
                img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
                txt_emb = torch.nn.functional.normalize(txt_emb, dim=-1)
                logits = (img_emb @ txt_emb.t()).squeeze(0)
                probs = torch.softmax(logits.float().cpu(), dim=-1).tolist()

            # Average the prompt probabilities within each class group;
            # slicing follows the insertion order of `labels` / `texts`.
            idx, scores = 0, {}
            for key, group in labels.items():
                g = probs[idx: idx + len(group)]
                scores[key] = sum(g) / len(g)
                idx += len(group)

            best = max(scores, key=scores.get)
            results.append((best, round(float(scores[best]), 3)))

        # Vote across images: the winning class is the one with the
        # highest mean confidence among the images that chose it.
        votos = {}
        for label, conf_i in results:
            votos.setdefault(label, []).append(conf_i)
        final = max(votos, key=lambda k: sum(votos[k]) / len(votos[k]))
        conf = round(sum(votos[final]) / len(votos[final]), 3)
        latency = round(time.time() - t0, 2)

        print(f"[OK] {final} ({conf}) em {latency}s")
        return JSONResponse({"label": final, "conf": conf, "latency_s": latency})

    except Exception as e:
        # Boundary handler: report the failure to the ESP32 client as
        # JSON instead of letting FastAPI emit an HTML error page.
        print(f"[ERRO] {e}")
        return JSONResponse({"error": str(e)}, status_code=500)