StoneSeller commited on
Commit
7545949
·
verified ·
1 Parent(s): 6787f69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -71,11 +71,10 @@ except Exception as e:
71
  print(f"Error loading model: {str(e)}")
72
  traceback.print_exc()
73
 
74
- # Define image transformation pipeline to match training
75
  transform = transforms.Compose([
76
  transforms.Resize((128, 128)),
 
77
  transforms.ToTensor(),
78
- # Using standard normalization as in training
79
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80
  ])
81
 
@@ -84,15 +83,17 @@ def process_image(image):
84
  return None
85
 
86
  try:
87
- # Convert numpy array to PIL Image
88
  if isinstance(image, np.ndarray):
89
- image = Image.fromarray(image)
90
 
91
- # Convert to RGB if necessary
92
  if image.mode != 'RGB':
93
  image = image.convert('RGB')
94
-
95
- # Print debug information
 
 
96
  print(f"Processed image size: {image.size}")
97
  print(f"Processed image mode: {image.mode}")
98
 
@@ -107,16 +108,32 @@ def predict(image):
107
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
108
 
109
  try:
110
- # Process the image
111
  processed_image = process_image(image)
112
  if processed_image is None:
113
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
114
 
115
- # Transform for model
116
- tensor_image = transform(processed_image).unsqueeze(0)
117
- print(f"Input tensor shape: {tensor_image.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- # Make prediction
120
  with torch.no_grad():
121
  outputs = model(tensor_image)
122
  print(f"Raw outputs: {outputs}")
@@ -124,7 +141,7 @@ def predict(image):
124
  probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
125
  print(f"Probabilities: {probabilities}")
126
 
127
- # Return results
128
  classes = ["Rope", "Hammer", "Other"]
129
  results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
130
  print(f"Final results: {results}")
@@ -135,15 +152,15 @@ def predict(image):
135
  traceback.print_exc()
136
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
137
 
138
- # Gradio interface
139
  interface = gr.Interface(
140
  fn=predict,
141
- inputs=gr.Image(), # Accept any image format
142
  outputs=gr.Label(num_top_classes=3),
143
  title="Mechanical Tools Classifier",
144
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
145
  )
146
 
147
- # Launch the interface
148
  if __name__ == "__main__":
149
  interface.launch()
 
71
  print(f"Error loading model: {str(e)}")
72
  traceback.print_exc()
73
 
 
74
  transform = transforms.Compose([
75
  transforms.Resize((128, 128)),
76
+ transforms.Lambda(lambda x: x.convert('RGB')),
77
  transforms.ToTensor(),
 
78
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
79
  ])
80
 
 
83
  return None
84
 
85
  try:
86
+ # numpy array PIL Image로 변환
87
  if isinstance(image, np.ndarray):
88
+ image = Image.fromarray(image.astype('uint8'))
89
 
90
+ # 이미지가 RGB 아니면 변환
91
  if image.mode != 'RGB':
92
  image = image.convert('RGB')
93
+
94
+ # 이미지 크기 조정
95
+ image = image.resize((128, 128), Image.Resampling.LANCZOS)
96
+
97
  print(f"Processed image size: {image.size}")
98
  print(f"Processed image mode: {image.mode}")
99
 
 
108
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
109
 
110
  try:
111
+ # 이미지 처리
112
  processed_image = process_image(image)
113
  if processed_image is None:
114
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
115
 
116
+ # PIL Image를 텐서로 변환
117
+ try:
118
+ # PIL Image를 numpy array로 변환
119
+ img_array = np.array(processed_image)
120
+ # numpy array를 torch tensor로 변환
121
+ tensor_image = torch.from_numpy(img_array.transpose((2, 0, 1))).float() / 255.0
122
+ # 정규화
123
+ tensor_image = transforms.Normalize(
124
+ mean=[0.485, 0.456, 0.406],
125
+ std=[0.229, 0.224, 0.225]
126
+ )(tensor_image)
127
+ # 배치 차원 추가
128
+ tensor_image = tensor_image.unsqueeze(0)
129
+
130
+ print(f"Input tensor shape: {tensor_image.shape}")
131
+ except Exception as e:
132
+ print(f"Error in tensor conversion: {str(e)}")
133
+ traceback.print_exc()
134
+ return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
135
 
136
+ # 예측 수행
137
  with torch.no_grad():
138
  outputs = model(tensor_image)
139
  print(f"Raw outputs: {outputs}")
 
141
  probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
142
  print(f"Probabilities: {probabilities}")
143
 
144
+ # 결과 반환
145
  classes = ["Rope", "Hammer", "Other"]
146
  results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
147
  print(f"Final results: {results}")
 
152
  traceback.print_exc()
153
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
154
 
155
+ # Gradio 인터페이스
156
  interface = gr.Interface(
157
  fn=predict,
158
+ inputs=gr.Image(),
159
  outputs=gr.Label(num_top_classes=3),
160
  title="Mechanical Tools Classifier",
161
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
162
  )
163
 
164
+ # 인터페이스 실행
165
  if __name__ == "__main__":
166
  interface.launch()