Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -67,112 +67,21 @@ def predict_next_move(fen, stockfish):
|
|
| 67 |
return "Invalid FEN notation!"
|
| 68 |
|
| 69 |
best_move = stockfish.get_best_move()
|
| 70 |
-
|
| 71 |
-
|
| 72 |
|
| 73 |
-
def process_image(image_path):
|
| 74 |
-
# Ensure output directory exists
|
| 75 |
-
if not os.path.exists('output'):
|
| 76 |
-
os.makedirs('output')
|
| 77 |
-
|
| 78 |
-
# Load the segmentation model
|
| 79 |
-
segmentation_model = YOLO("segmentation.pt")
|
| 80 |
-
|
| 81 |
-
# Run inference to get segmentation results
|
| 82 |
-
results = segmentation_model.predict(
|
| 83 |
-
source=image_path,
|
| 84 |
-
conf=0.8 # Confidence threshold
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# Initialize variables for the segmented mask and bounding box
|
| 88 |
-
segmentation_mask = None
|
| 89 |
-
bbox = None
|
| 90 |
-
|
| 91 |
-
for result in results:
|
| 92 |
-
if result.boxes.conf[0] >= 0.8: # Filter results by confidence
|
| 93 |
-
segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
|
| 94 |
-
bbox = result.boxes.xyxy[0].cpu().numpy() # Get the bounding box coordinates
|
| 95 |
-
break
|
| 96 |
-
|
| 97 |
-
if segmentation_mask is None:
|
| 98 |
-
print("No segmentation mask with confidence above 0.8 found.")
|
| 99 |
-
return None
|
| 100 |
-
|
| 101 |
-
# Load the image
|
| 102 |
-
image = cv2.imread(image_path)
|
| 103 |
-
|
| 104 |
-
# Convert the image to RGB format
|
| 105 |
-
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 106 |
-
|
| 107 |
-
# Resize segmentation mask to match the input image dimensions
|
| 108 |
-
segmentation_mask_resized = cv2.resize(segmentation_mask, (image.shape[1], image.shape[0]))
|
| 109 |
-
|
| 110 |
-
# Extract bounding box coordinates
|
| 111 |
-
if bbox is not None:
|
| 112 |
-
x1, y1, x2, y2 = bbox
|
| 113 |
-
# Crop the segmented region based on the bounding box
|
| 114 |
-
cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]
|
| 115 |
-
|
| 116 |
-
# Convert the cropped segment to RGB
|
| 117 |
-
cropped_segment_rgb = cv2.cvtColor(cropped_segment, cv2.COLOR_BGR2RGB)
|
| 118 |
-
|
| 119 |
-
# Save the cropped segmented image
|
| 120 |
-
cropped_image_path = 'output/cropped_segment.jpg'
|
| 121 |
-
cv2.imwrite(cropped_image_path, cropped_segment)
|
| 122 |
-
print(f"Cropped segmented image saved to {cropped_image_path}")
|
| 123 |
-
|
| 124 |
-
# Display the image in Streamlit
|
| 125 |
-
st.image(cropped_segment_rgb, caption="Uploaded Image (Cropped)", use_column_width=True)
|
| 126 |
-
|
| 127 |
-
# Return the cropped RGB image
|
| 128 |
-
return cropped_segment_rgb
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def transform_string(input_str):
|
| 132 |
-
# Remove extra spaces and convert to lowercase
|
| 133 |
-
input_str = input_str.strip().lower()
|
| 134 |
-
|
| 135 |
-
# Check if input is valid
|
| 136 |
-
if len(input_str) != 4 or not input_str[0].isalpha() or not input_str[1].isdigit() or \
|
| 137 |
-
not input_str[2].isalpha() or not input_str[3].isdigit():
|
| 138 |
-
return "Invalid input"
|
| 139 |
-
|
| 140 |
-
# Define mappings
|
| 141 |
-
letter_mapping = {
|
| 142 |
-
'a': 'h', 'b': 'g', 'c': 'f', 'd': 'e',
|
| 143 |
-
'e': 'd', 'f': 'c', 'g': 'b', 'h': 'a'
|
| 144 |
-
}
|
| 145 |
-
number_mapping = {
|
| 146 |
-
'1': '8', '2': '7', '3': '6', '4': '5',
|
| 147 |
-
'5': '4', '6': '3', '7': '2', '8': '1'
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
-
# Transform string
|
| 151 |
-
result = ""
|
| 152 |
-
for i, char in enumerate(input_str):
|
| 153 |
-
if i % 2 == 0: # Letters
|
| 154 |
-
result += letter_mapping.get(char, "Invalid")
|
| 155 |
-
else: # Numbers
|
| 156 |
-
result += number_mapping.get(char, "Invalid")
|
| 157 |
-
|
| 158 |
-
# Check for invalid transformations
|
| 159 |
-
if "Invalid" in result:
|
| 160 |
-
return "Invalid input"
|
| 161 |
-
|
| 162 |
-
return result
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# Streamlit app
|
| 167 |
-
def main():
|
| 168 |
-
st.title("Chessboard Position Detection and Move Prediction")
|
| 169 |
-
|
| 170 |
-
os.chmod("/home/user/app/stockfish-ubuntu-x86-64-sse41-popcnt", 0o755)
|
| 171 |
|
| 172 |
-
|
| 173 |
|
| 174 |
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
# User uploads an image or captures it from their camera
|
| 177 |
image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
|
| 178 |
|
|
@@ -184,96 +93,84 @@ def main():
|
|
| 184 |
with open(temp_file_path, "wb") as f:
|
| 185 |
f.write(image_file.getbuffer())
|
| 186 |
|
| 187 |
-
#
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
# Predict the next move
|
| 271 |
-
next_move = predict_next_move(fen_notation, stockfish)
|
| 272 |
-
st.subheader("Stockfish Recommended Move:")
|
| 273 |
-
st.write(next_move)
|
| 274 |
-
|
| 275 |
-
else:
|
| 276 |
-
st.error("Failed to process the image. Please try again.")
|
| 277 |
-
|
| 278 |
-
if __name__ == "__main__":
|
| 279 |
-
main()
|
|
|
|
| 67 |
return "Invalid FEN notation!"
|
| 68 |
|
| 69 |
best_move = stockfish.get_best_move()
|
| 70 |
+
return f"The predicted next move is: {best_move}" if best_move else "No valid move found (checkmate/stalemate)."
|
| 71 |
+
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main():
|
| 80 |
+
st.title("Chessboard Position Detection and Move Prediction")
|
| 81 |
+
|
| 82 |
+
# Set permissions for the Stockfish engine binary
|
| 83 |
+
os.chmod("/home/user/app/stockfish-ubuntu-x86-64-sse41-popcnt", 0o755)
|
| 84 |
+
|
| 85 |
# User uploads an image or captures it from their camera
|
| 86 |
image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
|
| 87 |
|
|
|
|
| 93 |
with open(temp_file_path, "wb") as f:
|
| 94 |
f.write(image_file.getbuffer())
|
| 95 |
|
| 96 |
+
# Load the YOLO models
|
| 97 |
+
model = YOLO("fine_tuned_on_all_data.pt") # Replace with your trained model weights file
|
| 98 |
+
seg_model = YOLO("segmentation.pt")
|
| 99 |
+
|
| 100 |
+
# Load and process the image
|
| 101 |
+
img = cv2.imread(temp_file_path)
|
| 102 |
+
r = seg_model.predict(source=temp_file_path)
|
| 103 |
+
xyxy = r[0].boxes.xyxy
|
| 104 |
+
x_min, y_min, x_max, y_max = map(int, xyxy[0])
|
| 105 |
+
new_img = img[y_min:y_max, x_min:x_max]
|
| 106 |
+
|
| 107 |
+
# Resize the image to 224x224
|
| 108 |
+
image = cv2.resize(new_img, (224, 224))
|
| 109 |
+
height, width, _ = image.shape
|
| 110 |
+
|
| 111 |
+
# Get user input for perspective
|
| 112 |
+
p = st.radio("Select perspective:", ["b (Black)", "w (White)"])
|
| 113 |
+
p = p[0].lower()
|
| 114 |
+
|
| 115 |
+
# Initialize the board for FEN (empty rows represented by "8")
|
| 116 |
+
board = [["8"] * 8 for _ in range(8)]
|
| 117 |
+
|
| 118 |
+
# Run detection
|
| 119 |
+
results = model.predict(source=image, save=False, save_txt=False, conf=0.6)
|
| 120 |
+
|
| 121 |
+
# Extract predictions and map to FEN board
|
| 122 |
+
for result in results[0].boxes:
|
| 123 |
+
x1, y1, x2, y2 = result.xyxy[0].tolist()
|
| 124 |
+
class_id = int(result.cls[0])
|
| 125 |
+
class_name = model.names[class_id]
|
| 126 |
+
|
| 127 |
+
fen_piece = FEN_MAPPING.get(class_name, None)
|
| 128 |
+
if not fen_piece:
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
center_x = (x1 + x2) / 2
|
| 132 |
+
center_y = (y1 + y2) / 2
|
| 133 |
+
pixel_x = int(center_x)
|
| 134 |
+
pixel_y = int(height - center_y)
|
| 135 |
+
|
| 136 |
+
grid_position = get_grid_coordinate(pixel_x, pixel_y, p)
|
| 137 |
+
if grid_position != "Pixel outside grid bounds":
|
| 138 |
+
file = ord(grid_position[0]) - ord('a')
|
| 139 |
+
rank = int(grid_position[1]) - 1
|
| 140 |
+
board[rank][file] = fen_piece
|
| 141 |
+
|
| 142 |
+
# Generate the FEN string
|
| 143 |
+
fen_rows = []
|
| 144 |
+
for row in board:
|
| 145 |
+
fen_row = ""
|
| 146 |
+
empty_count = 0
|
| 147 |
+
for cell in row:
|
| 148 |
+
if cell == "8":
|
| 149 |
+
empty_count += 1
|
| 150 |
+
else:
|
| 151 |
+
if empty_count > 0:
|
| 152 |
+
fen_row += str(empty_count)
|
| 153 |
+
empty_count = 0
|
| 154 |
+
fen_row += cell
|
| 155 |
+
if empty_count > 0:
|
| 156 |
+
fen_row += str(empty_count)
|
| 157 |
+
fen_rows.append(fen_row)
|
| 158 |
+
|
| 159 |
+
position_fen = "/".join(fen_rows)
|
| 160 |
+
move_side = st.radio("Select the side to move:", ["w (White)", "b (Black)"])[0].lower()
|
| 161 |
+
fen_notation = f"{position_fen} {move_side} - - 0 0"
|
| 162 |
+
st.subheader("Generated FEN Notation:")
|
| 163 |
+
st.code(fen_notation)
|
| 164 |
+
|
| 165 |
+
# Initialize the Stockfish engine
|
| 166 |
+
stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
|
| 167 |
+
stockfish = Stockfish(
|
| 168 |
+
path=stockfish_path,
|
| 169 |
+
depth=15,
|
| 170 |
+
parameters={"Threads": 2, "Minimum Thinking Time": 30}
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Predict the next move
|
| 174 |
+
next_move = predict_next_move(fen_notation, stockfish)
|
| 175 |
+
st.subheader("Stockfish Recommended Move:")
|
| 176 |
+
st.write(next_move)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|