StoneSeller commited on
Commit
b8cf886
·
verified ·
1 Parent(s): 6cc530c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -52
app.py CHANGED
@@ -2,35 +2,50 @@ import subprocess
2
  import sys
3
  import os
4
 
5
- # Function to install or reinstall specific packages
6
- def install(package):
7
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", package])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # First, ensure NumPy is installed with the correct version
10
- try:
11
- import numpy as np
12
- if not np.__version__.startswith("1.24"):
13
- print("Installing compatible NumPy version...")
14
- install("numpy==1.24.3")
15
- except ImportError:
16
- print("NumPy not found. Installing...")
17
- install("numpy==1.24.3")
18
-
19
- # Then install other dependencies
20
- packages = {
21
- "torch": "2.0.1",
22
- "torchvision": "0.15.2",
23
- "Pillow": "9.5.0",
24
- "gradio": "3.50.2"
25
- }
26
-
27
- for package, version in packages.items():
28
- try:
29
- __import__(package.lower())
30
- except ImportError:
31
- print(f"Installing {package}...")
32
- install(f"{package}=={version}")
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  import traceback
36
  import numpy as np
@@ -77,23 +92,39 @@ transform = transforms.Compose([
77
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
78
  ])
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def process_image(image):
81
  if image is None:
82
  return None
83
 
84
  try:
85
- # Convert numpy array to PIL Image
86
  if isinstance(image, np.ndarray):
87
- # Ensure the array is uint8
88
- if image.dtype != np.uint8:
89
- image = (image * 255).astype(np.uint8)
90
- image = Image.fromarray(image)
91
 
92
- # Convert to RGB if necessary
93
  if image.mode != 'RGB':
94
  image = image.convert('RGB')
95
 
96
- # Resize the image
97
  image = image.resize((128, 128), Image.Resampling.LANCZOS)
98
 
99
  print(f"Processed image size: {image.size}")
@@ -111,39 +142,45 @@ def predict(image):
111
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
112
 
113
  try:
114
- # Process the image
115
  processed_image = process_image(image)
116
  if processed_image is None:
117
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
118
 
119
- # Transform image to tensor
120
  try:
121
- # Convert PIL Image to tensor
122
- tensor_image = transform(processed_image)
123
- # Add batch dimension
124
- tensor_image = tensor_image.unsqueeze(0)
125
  print(f"Input tensor shape: {tensor_image.shape}")
126
  print(f"Tensor dtype: {tensor_image.dtype}")
127
  print(f"Tensor device: {tensor_image.device}")
 
128
  except Exception as e:
129
  print(f"Error in tensor conversion: {str(e)}")
130
  traceback.print_exc()
131
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
132
 
133
- # Make prediction
134
- with torch.no_grad():
135
- outputs = model(tensor_image)
136
- print(f"Raw outputs: {outputs}")
 
 
 
 
137
 
138
- probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
139
- print(f"Probabilities: {probabilities}")
 
 
 
 
 
 
 
 
140
 
141
- # Return results
142
- classes = ["Rope", "Hammer", "Other"]
143
- results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
144
- print(f"Final results: {results}")
145
- return results
146
-
147
  except Exception as e:
148
  print(f"Prediction error: {str(e)}")
149
  traceback.print_exc()
 
2
  import sys
3
  import os
4
 
5
+ # # Function to install or reinstall specific packages
6
+ # def install(package):
7
+ # subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", package])
8
+
9
+ # # First, ensure NumPy is installed with the correct version
10
+ # try:
11
+ # import numpy as np
12
+ # if not np.__version__.startswith("1.24"):
13
+ # print("Installing compatible NumPy version...")
14
+ # install("numpy==1.24.3")
15
+ # except ImportError:
16
+ # print("NumPy not found. Installing...")
17
+ # install("numpy==1.24.3")
18
+
19
+ # # Then install other dependencies
20
+ # packages = {
21
+ # "torch": "2.0.1",
22
+ # "torchvision": "0.15.2",
23
+ # "Pillow": "9.5.0",
24
+ # "gradio": "3.50.2"
25
+ # }
26
+
27
+ # for package, version in packages.items():
28
+ # try:
29
+ # __import__(package.lower())
30
+ # except ImportError:
31
+ # print(f"Installing {package}...")
32
+ # install(f"{package}=={version}")
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # 먼저 필요한 패키지들을 순서대로 설치
36
+ def install_requirements():
37
+ packages = [
38
+ "numpy==1.24.3",
39
+ "torch==2.0.1",
40
+ "torchvision==0.15.2",
41
+ "Pillow==9.5.0",
42
+ "gradio==3.50.2"
43
+ ]
44
+ for package in packages:
45
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", package])
46
+
47
+ # 패키지 설치 실행
48
+ install_requirements()
49
 
50
  import traceback
51
  import numpy as np
 
92
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
93
  ])
94
 
95
+ def custom_transform(pil_image):
96
+ # PIL Image를 numpy array로 변환
97
+ np_image = np.array(pil_image)
98
+
99
+ # numpy array를 torch tensor로 변환 (채널 순서 변경 포함)
100
+ tensor_image = torch.from_numpy(np_image.transpose((2, 0, 1))).float()
101
+
102
+ # 값 범위를 [0, 1]로 정규화
103
+ tensor_image = tensor_image / 255.0
104
+
105
+ # ImageNet 정규화 적용
106
+ normalize = transforms.Normalize(
107
+ mean=[0.485, 0.456, 0.406],
108
+ std=[0.229, 0.224, 0.225]
109
+ )
110
+ tensor_image = normalize(tensor_image)
111
+
112
+ return tensor_image
113
+
114
  def process_image(image):
115
  if image is None:
116
  return None
117
 
118
  try:
119
+ # numpy array PIL Image로 변환
120
  if isinstance(image, np.ndarray):
121
+ image = Image.fromarray(image.astype('uint8'))
 
 
 
122
 
123
+ # RGB 변환
124
  if image.mode != 'RGB':
125
  image = image.convert('RGB')
126
 
127
+ # 크기 조정
128
  image = image.resize((128, 128), Image.Resampling.LANCZOS)
129
 
130
  print(f"Processed image size: {image.size}")
 
142
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
143
 
144
  try:
145
+ # 이미지 전처리
146
  processed_image = process_image(image)
147
  if processed_image is None:
148
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
149
 
 
150
  try:
151
+ # 커스텀 변환 함수를 사용하여 텐서로 변환
152
+ tensor_image = custom_transform(processed_image)
153
+ tensor_image = tensor_image.unsqueeze(0) # 배치 차원 추가
154
+
155
  print(f"Input tensor shape: {tensor_image.shape}")
156
  print(f"Tensor dtype: {tensor_image.dtype}")
157
  print(f"Tensor device: {tensor_image.device}")
158
+
159
  except Exception as e:
160
  print(f"Error in tensor conversion: {str(e)}")
161
  traceback.print_exc()
162
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
163
 
164
+ # 예측 수행
165
+ try:
166
+ with torch.no_grad():
167
+ outputs = model(tensor_image)
168
+ print(f"Raw outputs: {outputs}")
169
+
170
+ probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
171
+ print(f"Probabilities: {probabilities}")
172
 
173
+ # 결과 반환
174
+ classes = ["Rope", "Hammer", "Other"]
175
+ results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
176
+ print(f"Final results: {results}")
177
+ return results
178
+
179
+ except Exception as e:
180
+ print(f"Error in prediction: {str(e)}")
181
+ traceback.print_exc()
182
+ return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
183
 
 
 
 
 
 
 
184
  except Exception as e:
185
  print(f"Prediction error: {str(e)}")
186
  traceback.print_exc()