MODLI commited on
Commit
208a0ac
·
verified ·
1 Parent(s): d32e8d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -45
app.py CHANGED
@@ -3,59 +3,46 @@ from transformers import pipeline
3
  from PIL import Image
4
  import numpy as np
5
  import cv2
6
- from categories import FASHION_CATEGORIES # Importe ta liste de catégories
7
 
8
  # --- Configuration et Chargement des Modèles ---
9
- # Ces modèles sont chargés une seule fois au démarrage de l'application
10
  print("Loading segmentation model... This might take a minute.")
11
- # Modèle pour détourer et isoler le vêtement
12
  seg_pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
13
 
14
  print("Loading fashion classification model...")
15
- # Modèle de classification spécialisé dans la mode (FashionCLIP)
16
- # Un modèle qui fait de la classification d'images de mode
17
- class_pipe = pipeline("image-classification", model="asafaya/convnext-tiny-224-fashion")
18
  print("Models loaded successfully!")
19
 
20
  # --- Fonctions de Prétraitement ---
21
  def get_largest_segment(segments):
22
- """Trouve le segment le plus grand (le vêtement principal) parmi les résultats de segmentation."""
23
  largest_area = 0
24
  largest_segment = None
25
  for segment in segments:
26
  mask = segment['mask']
27
- area = np.sum(mask) # Calcule la surface du masque
28
  if area > largest_area:
29
  largest_area = area
30
  largest_segment = segment
31
  return largest_segment
32
 
33
  def crop_to_mask(image, mask):
34
- """
35
- Recadre l'image sur la zone délimitée par le masque.
36
- Cela permet de supprimer tout le fond inutile.
37
- """
38
  mask_np = np.array(mask)
39
- # Trouve les coordonnées des pixels blancs du masque
40
  y, x = np.where(mask_np > 0)
41
  if len(x) == 0 or len(y) == 0:
42
- return image # Retourne l'image originale si le masque est vide
43
 
44
  x_min, x_max = np.min(x), np.max(x)
45
  y_min, y_max = np.min(y), np.max(y)
46
 
47
- # Recadre l'image
48
  cropped_image = image.crop((x_min, y_min, x_max, y_max))
49
  return cropped_image
50
 
51
- # --- Fonction Principale de Classification ---
52
  def classify_image(input_image):
53
- """
54
- Fonction appelée par Gradio quand l'utilisateur upload une image.
55
- 1. Segmentation pour isoler le vêtement.
56
- 2. Classification sur le vêtement isolé.
57
- """
58
- # Convertit l'image Gradio en PIL Image
59
  pil_image = Image.fromarray(input_image)
60
 
61
  # ÉTAPE 1: SEGMENTATION
@@ -63,26 +50,22 @@ def classify_image(input_image):
63
  main_item = get_largest_segment(segments)
64
 
65
  if main_item is None:
66
- return "No clothing item detected. Please try another image."
67
 
68
- # ÉTAPE 2: PRÉTRAITEMENT - On isole le vêtement
69
  isolated_image = crop_to_mask(pil_image, main_item['mask'])
70
 
71
- # ÉTAPE 3: CLASSIFICATION FINE
72
- # On utilise le modèle FashionCLIP pour comparer l'image à notre liste de catégories
73
  predictions = class_pipe(isolated_image, candidate_labels=FASHION_CATEGORIES)
74
 
75
- # Formatage des résultats pour l'affichage
76
  result_text = "Classification Results (on isolated garment):\n"
77
- for pred in predictions:
78
- # Affiche le score en pourcentage
79
  result_text += f"- {pred['label']}: {pred['score']*100:.2f}%\n"
80
 
81
- # On retourne aussi l'image isolée pour debugger et voir ce que l'IA a vraiment analysé
82
  return result_text, isolated_image
83
 
84
  # --- Interface Gradio ---
85
- # Création de l'interface utilisateur
86
  with gr.Blocks(title="Fashion Category Classifier") as demo:
87
  gr.Markdown("# 👗 Fashion Category Classifier")
88
  gr.Markdown("Upload a picture of a clothing item. The AI will isolate it and classify it.")
@@ -94,25 +77,13 @@ with gr.Blocks(title="Fashion Category Classifier") as demo:
94
 
95
  with gr.Column():
96
  label_output = gr.Textbox(label="Classification Results", lines=6)
97
- image_output = gr.Image(label="Isolated Garment (what the AI analyzed)", type="pil")
98
 
99
- # Lie le bouton à la fonction
100
  classify_btn.click(
101
  fn=classify_image,
102
  inputs=image_input,
103
  outputs=[label_output, image_output]
104
  )
105
-
106
- # Quelques exemples pour que l'utilisateur teste facilement
107
- gr.Examples(
108
- examples=[ # Tu devras ajouter tes propres exemples d'images
109
- "examples/t-shirt.jpg",
110
- "examples/dress.jpg",
111
- "examples/jacket.jpg"
112
- ],
113
- inputs=image_input
114
- )
115
 
116
- # Lance l'application
117
  if __name__ == "__main__":
118
- demo.launch(debug=True)
 
3
  from PIL import Image
4
  import numpy as np
5
  import cv2
6
+ from categories import FASHION_CATEGORIES
7
 
8
  # --- Configuration et Chargement des Modèles ---
 
9
  print("Loading segmentation model... This might take a minute.")
 
10
  seg_pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
11
 
12
  print("Loading fashion classification model...")
13
+ # UTILISATION DU MODÈLE CLIP D'OPENAI - GARANTI PUBLIC
14
+ class_pipe = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")
 
15
  print("Models loaded successfully!")
16
 
17
  # --- Fonctions de Prétraitement ---
18
  def get_largest_segment(segments):
19
+ """Trouve le segment le plus grand."""
20
  largest_area = 0
21
  largest_segment = None
22
  for segment in segments:
23
  mask = segment['mask']
24
+ area = np.sum(mask)
25
  if area > largest_area:
26
  largest_area = area
27
  largest_segment = segment
28
  return largest_segment
29
 
30
  def crop_to_mask(image, mask):
31
+ """Recadre l'image sur la zone du masque."""
 
 
 
32
  mask_np = np.array(mask)
 
33
  y, x = np.where(mask_np > 0)
34
  if len(x) == 0 or len(y) == 0:
35
+ return image
36
 
37
  x_min, x_max = np.min(x), np.max(x)
38
  y_min, y_max = np.min(y), np.max(y)
39
 
 
40
  cropped_image = image.crop((x_min, y_min, x_max, y_max))
41
  return cropped_image
42
 
43
+ # --- Fonction Principale ---
44
  def classify_image(input_image):
45
+ """Fonction appelée par Gradio."""
 
 
 
 
 
46
  pil_image = Image.fromarray(input_image)
47
 
48
  # ÉTAPE 1: SEGMENTATION
 
50
  main_item = get_largest_segment(segments)
51
 
52
  if main_item is None:
53
+ return "No clothing item detected. Please try another image.", None
54
 
55
+ # ÉTAPE 2: ISOLATION
56
  isolated_image = crop_to_mask(pil_image, main_item['mask'])
57
 
58
+ # ÉTAPE 3: CLASSIFICATION
 
59
  predictions = class_pipe(isolated_image, candidate_labels=FASHION_CATEGORIES)
60
 
61
+ # Formatage des résultats
62
  result_text = "Classification Results (on isolated garment):\n"
63
+ for pred in predictions[:5]: # Top 5 résultats seulement
 
64
  result_text += f"- {pred['label']}: {pred['score']*100:.2f}%\n"
65
 
 
66
  return result_text, isolated_image
67
 
68
  # --- Interface Gradio ---
 
69
  with gr.Blocks(title="Fashion Category Classifier") as demo:
70
  gr.Markdown("# 👗 Fashion Category Classifier")
71
  gr.Markdown("Upload a picture of a clothing item. The AI will isolate it and classify it.")
 
77
 
78
  with gr.Column():
79
  label_output = gr.Textbox(label="Classification Results", lines=6)
80
+ image_output = gr.Image(label="Isolated Garment", type="pil")
81
 
 
82
  classify_btn.click(
83
  fn=classify_image,
84
  inputs=image_input,
85
  outputs=[label_output, image_output]
86
  )
 
 
 
 
 
 
 
 
 
 
87
 
 
88
  if __name__ == "__main__":
89
+ demo.launch()