cabrel09 commited on
Commit
62f4bef
·
verified ·
1 Parent(s): 362a944

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -12
app.py CHANGED
@@ -3,37 +3,73 @@ import onnxruntime as ort
3
  import numpy as np
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
 
 
6
 
7
- model_path = hf_hub_download(repo_id="cabrel09/insect-detection-model", filename="vit_insects.onnx")
8
- hf_hub_download(repo_id="cabrel09/insect-detection-model", filename="vit_insects.onnx.data")
9
- session = ort.InferenceSession(model_path)
10
-
11
- # L'ordre crucial des classes
12
  labels = [
13
  "Fall Armyworms", "Western Corn Rootworms", "Colorado Potato Beetles", "Thrips",
14
  "Corn Earworms", "Cabbage Loopers", "Armyworms", "Brown Marmorated Stink Bugs",
15
  "Tomato Hornworms", "Citrus Canker", "Aphids", "Corn Borers", "Fruit Flies",
16
- "Africanized Honey Bees", "Spider Mites"
17
  ]
18
 
19
- def predict(img):
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  img = img.convert("RGB").resize((224, 224))
22
  img_array = np.array(img).astype('float32') / 255.0
23
  img_array = (img_array - 0.5) / 0.5
24
  img_array = np.transpose(img_array, (2, 0, 1))
25
  img_array = np.expand_dims(img_array, axis=0)
26
 
 
27
  outputs = session.run(None, {session.get_inputs()[0].name: img_array})
28
  logits = outputs[0][0]
29
-
30
  probs = np.exp(logits - np.max(logits))
31
  probs /= probs.sum()
32
 
33
- # On associe chaque probabilité à son nom d'insecte
34
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
35
  except Exception as e:
36
  return {"error": str(e)}
37
 
38
- interface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label(num_top_classes=5))
39
- interface.launch()
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
+ import base64
7
+ import io
8
 
9
+ # Labels exacts
 
 
 
 
10
  labels = [
11
  "Fall Armyworms", "Western Corn Rootworms", "Colorado Potato Beetles", "Thrips",
12
  "Corn Earworms", "Cabbage Loopers", "Armyworms", "Brown Marmorated Stink Bugs",
13
  "Tomato Hornworms", "Citrus Canker", "Aphids", "Corn Borers", "Fruit Flies",
14
+ "Africanized Honey Bees (Killer Bees)", "Spider Mites"
15
  ]
16
 
17
+ # Chargement sécurisé
18
+ try:
19
+ model_path = hf_hub_download(repo_id="cabrel09/insect-detection-model", filename="vit_insects.onnx")
20
+ hf_hub_download(repo_id="cabrel09/insect-detection-model", filename="vit_insects.onnx.data")
21
+ session = ort.InferenceSession(model_path)
22
+ except Exception as e:
23
+ print(f"Erreur chargement: {e}")
24
+ session = None
25
+
26
+ def predict(img_input):
27
+ if session is None: return {"error": "Modèle non chargé"}
28
+ if img_input is None: return None
29
+
30
  try:
31
+ # GESTION HYBRIDE : Détection automatique du type d'entrée
32
+ if isinstance(img_input, str):
33
+ # Si c'est du base64 (via API)
34
+ if "base64," in img_input:
35
+ img_input = img_input.split("base64,")[1]
36
+ img_bytes = base64.b64decode(img_input)
37
+ img = Image.open(io.BytesIO(img_bytes))
38
+ elif isinstance(img_input, dict) and "url" in img_input:
39
+ # Si Gradio envoie un dictionnaire (via API structurée)
40
+ url_data = img_input["url"]
41
+ if "base64," in url_data:
42
+ url_data = url_data.split("base64,")[1]
43
+ img_bytes = base64.b64decode(url_data)
44
+ img = Image.open(io.BytesIO(img_bytes))
45
+ else:
46
+ # Si c'est déjà une image PIL (via l'interface Web)
47
+ img = img_input
48
+
49
+ # Preprocessing
50
  img = img.convert("RGB").resize((224, 224))
51
  img_array = np.array(img).astype('float32') / 255.0
52
  img_array = (img_array - 0.5) / 0.5
53
  img_array = np.transpose(img_array, (2, 0, 1))
54
  img_array = np.expand_dims(img_array, axis=0)
55
 
56
+ # Inférence
57
  outputs = session.run(None, {session.get_inputs()[0].name: img_array})
58
  logits = outputs[0][0]
 
59
  probs = np.exp(logits - np.max(logits))
60
  probs /= probs.sum()
61
 
62
+ return {labels[i]: float(probs[i]) for i in range(min(len(labels), len(probs)))}
63
+
64
  except Exception as e:
65
  return {"error": str(e)}
66
 
67
+ # Utilisation de gr.Image mais acceptation de types flexibles
68
+ demo = gr.Interface(
69
+ fn=predict,
70
+ inputs=gr.Image(type="pil"),
71
+ outputs=gr.Label(num_top_classes=5),
72
+ title="PlantPatrol Classifier"
73
+ )
74
+
75
+ demo.launch()