commit
Browse files- .gitignore +5 -0
- app.py +127 -16
- chess_board_tool.py +252 -0
- chess_pieces_detection/__init__.py +1 -0
- chess_pieces_detection/train.sh +2 -0
- chess_pieces_detection/train_chess_pieces_recognition.py +400 -0
- install-requirements.sh +2 -0
- my_prompt_config.py +57 -0
- my_tools.py +66 -0
- requirements.txt +9 -1
- run.sh +2 -0
- simple.py +87 -0
- simple.sh +3 -0
- test_tools.py +34 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/.env
|
| 2 |
+
/venvw/
|
| 3 |
+
/questions.json
|
| 4 |
+
/venv/
|
| 5 |
+
/.idea/
|
app.py
CHANGED
|
@@ -3,27 +3,121 @@ import gradio as gr
|
|
| 3 |
import requests
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# (Keep Constants as is)
|
| 8 |
# --- Constants ---
|
| 9 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# --- Basic Agent Definition ---
|
| 12 |
-
# ----- THIS IS
|
| 13 |
class BasicAgent:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def __init__(self):
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def __call__(self, question: str) -> str:
|
| 17 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 23 |
"""
|
| 24 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 25 |
and displays the results.
|
| 26 |
"""
|
|
|
|
| 27 |
# --- Determine HF Space Runtime URL and Repo URL ---
|
| 28 |
space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
|
| 29 |
|
|
@@ -76,10 +170,23 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 76 |
for item in questions_data:
|
| 77 |
task_id = item.get("task_id")
|
| 78 |
question_text = item.get("question")
|
|
|
|
|
|
|
| 79 |
if not task_id or question_text is None:
|
| 80 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 81 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
submitted_answer = agent(question_text)
|
| 84 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 85 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
|
@@ -99,17 +206,21 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 99 |
# 5. Submit
|
| 100 |
print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
|
| 101 |
try:
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
results_df = pd.DataFrame(results_log)
|
| 114 |
return final_status, results_df
|
| 115 |
except requests.exceptions.HTTPError as e:
|
|
|
|
| 3 |
import requests
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
| 6 |
+
from smolagents import CodeAgent, tool, InferenceClientModel, WebSearchTool, load_tool, PromptTemplates, Tool, FinalAnswerTool
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from my_prompt_config import PromptConfig
|
| 9 |
+
from my_tools import ReverseStringTool, ImageLoadTool
|
| 10 |
+
from PIL import Image
|
| 11 |
|
| 12 |
# (Keep Constants as is)
|
| 13 |
# --- Constants ---
|
| 14 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 15 |
|
| 16 |
+
# testing --------------------------------------------
|
| 17 |
+
testing_mode = True
|
| 18 |
+
questions_to_run = [
|
| 19 |
+
#"8e867cd7-cff9-4e6c-867a-ff5ddc2550be", # OK
|
| 20 |
+
#"a1e91b78-d3d8-4675-bb8d-62741b4b68a6"
|
| 21 |
+
#"2d83110e-a098-4ebb-9987-066c06fa42d0" # almost OK
|
| 22 |
+
"cca530fc-4052-43b2-b130-b30968d8aa44"
|
| 23 |
+
#"4fc2f1ae-8625-45b5-ab34-ad4433bc21f8"
|
| 24 |
+
#"6f37996b-2ac7-44b0-8e68-6d28256631b4"
|
| 25 |
+
#"9d191bce-651d-4746-be2d-7ef8ecadb9c2"
|
| 26 |
+
#"cabe07ed-9eca-40ea-8ead-410ef5e83f91"
|
| 27 |
+
#"3cef3a44-215e-4aed-8e3b-b1e3f08063b7"
|
| 28 |
+
#"99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3"
|
| 29 |
+
#"305ac316-eef6-4446-960a-92d80d542f82"
|
| 30 |
+
#"f918266a-b3e0-4914-865d-4faa564f1aef"
|
| 31 |
+
#"3f57289b-8c60-48be-bd80-01f8099ca449"
|
| 32 |
+
#"1f975693-876d-457b-a649-393859e79bf3"
|
| 33 |
+
#"840bfca7-4f7b-481a-8794-c560c340185d"
|
| 34 |
+
#"bda648d7-d618-4883-88f4-3466eabd860e"
|
| 35 |
+
#"cf106601-ab4f-4af9-b045-5295fe67b37d"
|
| 36 |
+
#"a0c07678-e491-4bbc-8f0b-07405144218f"
|
| 37 |
+
#"7bd855d8-463d-4ed5-93ca-5fe35145f733"
|
| 38 |
+
#"5a0c1adf-205e-4841-a666-7c3ef95def9d"
|
| 39 |
+
]
|
| 40 |
+
# testing --------------------------------------------
|
| 41 |
+
|
| 42 |
# --- Basic Agent Definition ---
|
| 43 |
+
# ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
|
| 44 |
class BasicAgent:
|
| 45 |
+
#MODEL_CODER = "Qwen/Qwen2.5-Coder-32B-Instruct"
|
| 46 |
+
MODEL_CODER = "Qwen/Qwen2.5-72B-Instruct"
|
| 47 |
+
#MODEL_REASONING = "deepseek-ai/DeepSeek-R1"
|
| 48 |
+
MODEL_REASONING = "Qwen/Qwen2.5-72B-Instruct"
|
| 49 |
+
|
| 50 |
def __init__(self):
|
| 51 |
+
load_dotenv()
|
| 52 |
+
print("Agent initialized.")
|
| 53 |
+
self.__create_agents__()
|
| 54 |
+
|
| 55 |
+
def __create_agents__(self):
|
| 56 |
+
web_search_agent = CodeAgent(
|
| 57 |
+
tools=[WebSearchTool()],
|
| 58 |
+
model=InferenceClientModel(model_id=self.MODEL_CODER),
|
| 59 |
+
name="agent_websearch",
|
| 60 |
+
description="Agent to browse and search and extract web content"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
#reverse_agent = CodeAgent(
|
| 64 |
+
# tools=[ReverseStringTool()],
|
| 65 |
+
# model=InferenceClientModel(model_id=self.MODEL_CODER),
|
| 66 |
+
# name="agent_reverse",
|
| 67 |
+
# description="Agent to reverse strings"
|
| 68 |
+
#)
|
| 69 |
+
|
| 70 |
+
# self.image_generation_tool = load_tool("m-ric/text-to-image", trust_remote_code=True)
|
| 71 |
+
image_generation_tool = Tool.from_space(
|
| 72 |
+
"black-forest-labs/FLUX.1-schnell",
|
| 73 |
+
name="image_generator",
|
| 74 |
+
description="Generate an image from a prompt"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
image_captioning_tool = Tool.from_space(
|
| 78 |
+
"ovi054/image-to-prompt",
|
| 79 |
+
name="image_captioning",
|
| 80 |
+
description="Generate description of an image"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
image_loading_tool = ImageLoadTool()
|
| 84 |
+
print(f"Image load tool: {image_loading_tool}")
|
| 85 |
+
|
| 86 |
+
#image_generation_agent = CodeAgent(
|
| 87 |
+
# tools=[image_generation_tool],
|
| 88 |
+
# model=InferenceClientModel(model_id=self.MODEL_CODER)
|
| 89 |
+
#)
|
| 90 |
+
|
| 91 |
+
# ImageLoadTool()
|
| 92 |
+
self.reasoning_agent = CodeAgent(
|
| 93 |
+
#tools=[image_generation_tool, image_captioning_tool, ReverseStringTool(), image_loading_tool],
|
| 94 |
+
tools=[image_loading_tool, FinalAnswerTool()],
|
| 95 |
+
model=InferenceClientModel(model_id=self.MODEL_REASONING),
|
| 96 |
+
planning_interval=3, # This is where you activate planning!,
|
| 97 |
+
prompt_templates=PromptConfig().PROMPT_TEMPLATES,
|
| 98 |
+
managed_agents=[web_search_agent],
|
| 99 |
+
additional_authorized_imports=["PIL","chess","my_tools","my_tools."],
|
| 100 |
+
)
|
| 101 |
+
print(f"Main agent initialized: {self.reasoning_agent}")
|
| 102 |
+
|
| 103 |
def __call__(self, question: str) -> str:
|
| 104 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 105 |
+
|
| 106 |
+
answer = self.reasoning_agent.run(question)
|
| 107 |
+
|
| 108 |
+
print(f"Agent returning answer: {answer}")
|
| 109 |
+
return answer
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# -----------------------------------------------------------------------------
|
| 113 |
+
#
|
| 114 |
|
| 115 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 116 |
"""
|
| 117 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 118 |
and displays the results.
|
| 119 |
"""
|
| 120 |
+
|
| 121 |
# --- Determine HF Space Runtime URL and Repo URL ---
|
| 122 |
space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
|
| 123 |
|
|
|
|
| 170 |
for item in questions_data:
|
| 171 |
task_id = item.get("task_id")
|
| 172 |
question_text = item.get("question")
|
| 173 |
+
question_file_name = item.get("file_name")
|
| 174 |
+
|
| 175 |
if not task_id or question_text is None:
|
| 176 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 177 |
continue
|
| 178 |
+
|
| 179 |
+
if testing_mode:
|
| 180 |
+
if task_id not in questions_to_run:
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
try:
|
| 184 |
+
if question_file_name is not None:
|
| 185 |
+
ext = question_file_name[-4:]
|
| 186 |
+
if ext == ".png":
|
| 187 |
+
question_text = question_text + (f" . Use available tool to load an image associated with task id: "
|
| 188 |
+
f"{task_id}")
|
| 189 |
+
|
| 190 |
submitted_answer = agent(question_text)
|
| 191 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 192 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
|
|
|
| 206 |
# 5. Submit
|
| 207 |
print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
|
| 208 |
try:
|
| 209 |
+
if not testing_mode:
|
| 210 |
+
response = requests.post(submit_url, json=submission_data, timeout=60)
|
| 211 |
+
response.raise_for_status()
|
| 212 |
+
result_data = response.json()
|
| 213 |
+
final_status = (
|
| 214 |
+
f"Submission Successful!\n"
|
| 215 |
+
f"User: {result_data.get('username')}\n"
|
| 216 |
+
f"Overall Score: {result_data.get('score', 'N/A')}% "
|
| 217 |
+
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
| 218 |
+
f"Message: {result_data.get('message', 'No message received.')}"
|
| 219 |
+
)
|
| 220 |
+
print("Submission successful.")
|
| 221 |
+
else:
|
| 222 |
+
final_status = "TESTING, Submission skipped"
|
| 223 |
+
|
| 224 |
results_df = pd.DataFrame(results_log)
|
| 225 |
return final_status, results_df
|
| 226 |
except requests.exceptions.HTTPError as e:
|
chess_board_tool.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from smolagents import Tool
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import math
|
| 7 |
+
import numpy
|
| 8 |
+
from .chess_pieces_detection import ChessPiecesRecognition
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ChessBoard(Tool):
|
| 12 |
+
name = "_my_chess_board"
|
| 13 |
+
description = """
|
| 14 |
+
Analyze an image representing a chess board and extract board state in FEN notation
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
inputs = {
|
| 18 |
+
"img": {
|
| 19 |
+
"type": "image",
|
| 20 |
+
"description": "image of chess board to extract board position",
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
output_type = "string"
|
| 25 |
+
|
| 26 |
+
# Steps to do
|
| 27 |
+
# - break board image into array of images representing pieces
|
| 28 |
+
# Image -> Image[]
|
| 29 |
+
# - image recognition on set of images to get piece labels
|
| 30 |
+
# Image[] -> str[]
|
| 31 |
+
# - construct FEN from string of chess pieces
|
| 32 |
+
# str[] -> []
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def gradientx(self, img):
|
| 36 |
+
# Compute gradient in x-direction using larger Sobel kernel
|
| 37 |
+
grad_x = cv2.Sobel(img, cv2.CV_32F, 1, 0, ksize=31)
|
| 38 |
+
return grad_x
|
| 39 |
+
|
| 40 |
+
def gradienty(self, img):
|
| 41 |
+
# Compute gradient in y-direction using larger Sobel kernel
|
| 42 |
+
grad_y = cv2.Sobel(img, cv2.CV_32F, 0, 1, ksize=31)
|
| 43 |
+
return grad_y
|
| 44 |
+
|
| 45 |
+
def checkMatch(self, lineset):
|
| 46 |
+
linediff = np.diff(lineset)
|
| 47 |
+
x = 0
|
| 48 |
+
cnt = 0
|
| 49 |
+
for line in linediff:
|
| 50 |
+
if abs(line - x) < 5:
|
| 51 |
+
cnt += 1
|
| 52 |
+
else:
|
| 53 |
+
cnt = 0
|
| 54 |
+
x = line
|
| 55 |
+
return cnt == 5
|
| 56 |
+
|
| 57 |
+
def pruneLines(self, lineset, image_dim, margin=20):
|
| 58 |
+
# Remove lines near the margins
|
| 59 |
+
lineset = [x for x in lineset if x > margin and x < image_dim - margin]
|
| 60 |
+
if not lineset:
|
| 61 |
+
return lineset
|
| 62 |
+
linediff = np.diff(lineset)
|
| 63 |
+
x = 0
|
| 64 |
+
cnt = 0
|
| 65 |
+
start_pos = 0
|
| 66 |
+
for i, line in enumerate(linediff):
|
| 67 |
+
if abs(line - x) < 5:
|
| 68 |
+
cnt += 1
|
| 69 |
+
if cnt == 5:
|
| 70 |
+
end_pos = i + 2
|
| 71 |
+
return lineset[start_pos:end_pos]
|
| 72 |
+
else:
|
| 73 |
+
cnt = 0
|
| 74 |
+
x = line
|
| 75 |
+
start_pos = i
|
| 76 |
+
return lineset
|
| 77 |
+
|
| 78 |
+
def skeletonize_1d(self, arr):
|
| 79 |
+
_arr = arr.copy()
|
| 80 |
+
for i in range(len(_arr) - 1):
|
| 81 |
+
if _arr[i] <= _arr[i + 1]:
|
| 82 |
+
_arr[i] = 0
|
| 83 |
+
for i in range(len(_arr) - 1, 0, -1):
|
| 84 |
+
if _arr[i - 1] > _arr[i]:
|
| 85 |
+
_arr[i] = 0
|
| 86 |
+
return _arr
|
| 87 |
+
|
| 88 |
+
def getChessLines(self, hdx, hdy, hdx_thresh, hdy_thresh, image_shape):
|
| 89 |
+
# Generate Gaussian window
|
| 90 |
+
window_size = 21
|
| 91 |
+
sigma = 8.0
|
| 92 |
+
gausswin = cv2.getGaussianKernel(window_size, sigma, cv2.CV_64F)
|
| 93 |
+
gausswin = gausswin.flatten()
|
| 94 |
+
half_size = window_size // 2
|
| 95 |
+
|
| 96 |
+
# Threshold signals
|
| 97 |
+
hdx_thresh_binary = np.where(hdx > hdx_thresh, 1.0, 0.0)
|
| 98 |
+
hdy_thresh_binary = np.where(hdy > hdy_thresh, 1.0, 0.0)
|
| 99 |
+
|
| 100 |
+
# Blur signals using convolution with Gaussian window
|
| 101 |
+
blur_x = np.convolve(hdx_thresh_binary, gausswin, mode='same')
|
| 102 |
+
blur_y = np.convolve(hdy_thresh_binary, gausswin, mode='same')
|
| 103 |
+
|
| 104 |
+
# Skeletonize signals
|
| 105 |
+
skel_x = self.skeletonize_1d(blur_x)
|
| 106 |
+
skel_y = self.skeletonize_1d(blur_y)
|
| 107 |
+
|
| 108 |
+
# Find line positions
|
| 109 |
+
lines_x = np.where(skel_x > 0)[0].tolist()
|
| 110 |
+
lines_y = np.where(skel_y > 0)[0].tolist()
|
| 111 |
+
|
| 112 |
+
# Prune lines
|
| 113 |
+
lines_x = self.pruneLines(lines_x, image_shape[1])
|
| 114 |
+
lines_y = self.pruneLines(lines_y, image_shape[0])
|
| 115 |
+
|
| 116 |
+
# Check if lines match expected pattern
|
| 117 |
+
is_match = (len(lines_x) == 7) and (len(lines_y) == 7) and \
|
| 118 |
+
self.checkMatch(lines_x) and self.checkMatch(lines_y)
|
| 119 |
+
|
| 120 |
+
return lines_x, lines_y, is_match
|
| 121 |
+
|
| 122 |
+
def getChessTiles(self, img, lines_x, lines_y):
|
| 123 |
+
stepx = int(round(np.mean(np.diff(lines_x))))
|
| 124 |
+
stepy = int(round(np.mean(np.diff(lines_y))))
|
| 125 |
+
|
| 126 |
+
# Pad the image if necessary
|
| 127 |
+
padl_x = 0
|
| 128 |
+
padr_x = 0
|
| 129 |
+
padl_y = 0
|
| 130 |
+
padr_y = 0
|
| 131 |
+
if lines_x[0] - stepx < 0:
|
| 132 |
+
padl_x = abs(lines_x[0] - stepx)
|
| 133 |
+
if lines_x[-1] + stepx > img.shape[1] - 1:
|
| 134 |
+
padr_x = lines_x[-1] + stepx - img.shape[1] + 1
|
| 135 |
+
if lines_y[0] - stepy < 0:
|
| 136 |
+
padl_y = abs(lines_y[0] - stepy)
|
| 137 |
+
if lines_y[-1] + stepy > img.shape[0] - 1:
|
| 138 |
+
padr_y = lines_y[-1] + stepy - img.shape[0] + 1
|
| 139 |
+
|
| 140 |
+
img_padded = cv2.copyMakeBorder(img, padl_y, padr_y, padl_x, padr_x, cv2.BORDER_REPLICATE)
|
| 141 |
+
|
| 142 |
+
setsx = [lines_x[0] - stepx + padl_x] + [x + padl_x for x in lines_x] + [lines_x[-1] + stepx + padl_x]
|
| 143 |
+
setsy = [lines_y[0] - stepy + padl_y] + [y + padl_y for y in lines_y] + [lines_y[-1] + stepy + padl_y]
|
| 144 |
+
|
| 145 |
+
squares = []
|
| 146 |
+
for j in range(8):
|
| 147 |
+
for i in range(8):
|
| 148 |
+
x1 = setsx[i]
|
| 149 |
+
x2 = setsx[i + 1]
|
| 150 |
+
y1 = setsy[j]
|
| 151 |
+
y2 = setsy[j + 1]
|
| 152 |
+
# Adjust sizes to ensure squares are of equal size
|
| 153 |
+
if (x2 - x1) != stepx:
|
| 154 |
+
x2 = x1 + stepx
|
| 155 |
+
if (y2 - y1) != stepy:
|
| 156 |
+
y2 = y1 + stepy
|
| 157 |
+
square = img_padded[y1:y2, x1:x2]
|
| 158 |
+
squares.append(square)
|
| 159 |
+
return squares
|
| 160 |
+
|
| 161 |
+
# Image(PIL) --> Image(CV2)[]
|
| 162 |
+
def extract_pieces_from_image_board(self, image):
|
| 163 |
+
# Load the image
|
| 164 |
+
if image is None:
|
| 165 |
+
print(f"Image not provided")
|
| 166 |
+
return
|
| 167 |
+
# Convert to grayscale
|
| 168 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
| 169 |
+
|
| 170 |
+
# Preprocessing
|
| 171 |
+
equ = cv2.equalizeHist(gray)
|
| 172 |
+
norm_image = equ.astype(np.float32) / 255.0
|
| 173 |
+
|
| 174 |
+
# Compute the gradients
|
| 175 |
+
grad_x = self.gradientx(norm_image)
|
| 176 |
+
grad_y = self.gradienty(norm_image)
|
| 177 |
+
|
| 178 |
+
# Clip the gradients
|
| 179 |
+
Dx_pos = np.clip(grad_x, 0, None)
|
| 180 |
+
Dx_neg = np.clip(-grad_x, 0, None)
|
| 181 |
+
Dy_pos = np.clip(grad_y, 0, None)
|
| 182 |
+
Dy_neg = np.clip(-grad_y, 0, None)
|
| 183 |
+
|
| 184 |
+
# Compute the Hough transform
|
| 185 |
+
hough_Dx = (np.sum(Dx_pos, axis=0) * np.sum(Dx_neg, axis=0)) / (norm_image.shape[0] ** 2)
|
| 186 |
+
hough_Dy = (np.sum(Dy_pos, axis=1) * np.sum(Dy_neg, axis=1)) / (norm_image.shape[1] ** 2)
|
| 187 |
+
|
| 188 |
+
# Adaptive thresholding
|
| 189 |
+
a = 1
|
| 190 |
+
is_match = False
|
| 191 |
+
lines_x = []
|
| 192 |
+
lines_y = []
|
| 193 |
+
|
| 194 |
+
while a < 5:
|
| 195 |
+
threshold_x = np.max(hough_Dx) * (a / 5.0)
|
| 196 |
+
threshold_y = np.max(hough_Dy) * (a / 5.0)
|
| 197 |
+
|
| 198 |
+
lines_x, lines_y, is_match = self.getChessLines(hough_Dx, hough_Dy, threshold_x, threshold_y,
|
| 199 |
+
norm_image.shape)
|
| 200 |
+
|
| 201 |
+
if is_match:
|
| 202 |
+
break
|
| 203 |
+
else:
|
| 204 |
+
a += 1
|
| 205 |
+
|
| 206 |
+
squares_resized = []
|
| 207 |
+
if is_match:
|
| 208 |
+
squares = self.getChessTiles(gray, lines_x, lines_y)
|
| 209 |
+
for square in squares:
|
| 210 |
+
resized = cv2.resize(square, (32, 32), interpolation=cv2.INTER_AREA)
|
| 211 |
+
squares_resized.append(resized)
|
| 212 |
+
|
| 213 |
+
#print("7 horizontal and vertical lines found, slicing up squares")
|
| 214 |
+
#squares = self.getChessTiles(gray, lines_x, lines_y)
|
| 215 |
+
#print(f"Tiles generated: ({squares[0].shape[0]}x{squares[0].shape[1]}) * {len(squares)}")
|
| 216 |
+
|
| 217 |
+
# Extract filename and FEN (assuming filename is FEN)
|
| 218 |
+
#img_save_dir = os.path.join("/mnt/c/Users/krzsa/IdeaProjects/Agents-Course-Assignment/chess-pieces")
|
| 219 |
+
|
| 220 |
+
#letters = "ABCDEFGH"
|
| 221 |
+
#for i, square in enumerate(squares):
|
| 222 |
+
# filename = f"fen_{letters[i % 8]}{(i // 8) + 1}.png"
|
| 223 |
+
# save_path = os.path.join(img_save_dir, filename)
|
| 224 |
+
# if i % 8 == 0:
|
| 225 |
+
# print(f"#{i}: saving {save_path}...")
|
| 226 |
+
# # Resize to 32x32 and save
|
| 227 |
+
# resized = cv2.resize(square, (32, 32), interpolation=cv2.INTER_AREA)
|
| 228 |
+
# cv2.imwrite(save_path, resized)
|
| 229 |
+
|
| 230 |
+
return squares_resized
|
| 231 |
+
|
| 232 |
+
def detect_chess_pieces(self, images):
|
| 233 |
+
recognition = ChessPiecesRecognition()
|
| 234 |
+
return recognition.classify_pieces(images)
|
| 235 |
+
|
| 236 |
+
def convert_pieces_list_to_fen(self, pieces):
|
| 237 |
+
print()
|
| 238 |
+
|
| 239 |
+
def forward(self, img: Image) -> str:
|
| 240 |
+
print(f"***KS*** Analyzing chess board image")
|
| 241 |
+
cv2_image = cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
|
| 242 |
+
|
| 243 |
+
# Image(PIL) -> Image(CV2)(32x32) []
|
| 244 |
+
squares_resized = self.extract_pieces_from_image_board(cv2_image)
|
| 245 |
+
|
| 246 |
+
# Image(CV2)(32x32) [] -> str[]
|
| 247 |
+
pieces_list = self.detect_chess_pieces(squares_resized)
|
| 248 |
+
|
| 249 |
+
# str[] -> str(FEN)
|
| 250 |
+
fen = self.convert_pieces_list_to_fen(pieces_list)
|
| 251 |
+
|
| 252 |
+
return f"FEN is: \"{fen}\n "
|
chess_pieces_detection/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from train_chess_pieces_recognition import ChessPiecesRecognition
|
chess_pieces_detection/train.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
. ../venv/bin/activate
|
| 2 |
+
python3 train_chess_pieces_recognition.py
|
chess_pieces_detection/train_chess_pieces_recognition.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# https://en.wikipedia.org/wiki/Convolution
|
| 12 |
+
# https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
| 13 |
+
# https://en.wikipedia.org/wiki/Forsyth%E2%80%93Edwards_Notation
|
| 14 |
+
|
| 15 |
+
# piece types
|
| 16 |
+
# - white Rook R
|
| 17 |
+
# - white Knights N
|
| 18 |
+
# - white Bishop B
|
| 19 |
+
# - white Queen Q
|
| 20 |
+
# - white King K
|
| 21 |
+
# - white Pawn P
|
| 22 |
+
# - black Rook r
|
| 23 |
+
# - black Knights n
|
| 24 |
+
# - black Bishop b
|
| 25 |
+
# - black Queen q
|
| 26 |
+
# - black King k
|
| 27 |
+
# - black Pawn p
|
| 28 |
+
# - empty
|
| 29 |
+
|
| 30 |
+
TRAIN_DIR = "/mnt/c/Users/krzsa/IdeaProjects/Agents-Course-Assignment/chess_pieces_detection/train-data"
|
| 31 |
+
TRAIN_DIR_BLACK = f"{TRAIN_DIR}/black"
|
| 32 |
+
TRAIN_DIR_WHITE = f"{TRAIN_DIR}/white"
|
| 33 |
+
TRAIN_DIR_EMPTY = f"{TRAIN_DIR}/empty"
|
| 34 |
+
|
| 35 |
+
#
|
| 36 |
+
# 0: 1
|
| 37 |
+
# 1: K
|
| 38 |
+
# 2: Q
|
| 39 |
+
# 3: R
|
| 40 |
+
# 4: B
|
| 41 |
+
# 5: N
|
| 42 |
+
# 6: P
|
| 43 |
+
# 7: k
|
| 44 |
+
# 8: q
|
| 45 |
+
# 9: r
|
| 46 |
+
# 10: b
|
| 47 |
+
# 11: n
|
| 48 |
+
# 12: p
|
| 49 |
+
TRAIN_DATA = [
|
| 50 |
+
(f"{TRAIN_DIR_EMPTY}/1_001.png", "1"),
|
| 51 |
+
(f"{TRAIN_DIR_EMPTY}/1_002.png", "1"),
|
| 52 |
+
(f"{TRAIN_DIR_BLACK}/b_001.png", "b"),
|
| 53 |
+
(f"{TRAIN_DIR_BLACK}/b_002.png", "b"),
|
| 54 |
+
(f"{TRAIN_DIR_BLACK}/k_001.png", "k"),
|
| 55 |
+
(f"{TRAIN_DIR_BLACK}/k_002.png", "k"),
|
| 56 |
+
(f"{TRAIN_DIR_BLACK}/n_001.png", "n"),
|
| 57 |
+
(f"{TRAIN_DIR_BLACK}/n_002.png", "n"),
|
| 58 |
+
(f"{TRAIN_DIR_BLACK}/p_001.png", "p"),
|
| 59 |
+
(f"{TRAIN_DIR_BLACK}/p_002.png", "p"),
|
| 60 |
+
(f"{TRAIN_DIR_BLACK}/q_001.png", "q"),
|
| 61 |
+
(f"{TRAIN_DIR_BLACK}/q_002.png", "q"),
|
| 62 |
+
(f"{TRAIN_DIR_BLACK}/r_001.png", "r"),
|
| 63 |
+
(f"{TRAIN_DIR_BLACK}/r_002.png", "r"),
|
| 64 |
+
(f"{TRAIN_DIR_WHITE}/B_001.png", "B"),
|
| 65 |
+
(f"{TRAIN_DIR_WHITE}/B_002.png", "B"),
|
| 66 |
+
(f"{TRAIN_DIR_WHITE}/K_001.png", "K"),
|
| 67 |
+
(f"{TRAIN_DIR_WHITE}/K_002.png", "K"),
|
| 68 |
+
(f"{TRAIN_DIR_WHITE}/N_001.png", "N"),
|
| 69 |
+
(f"{TRAIN_DIR_WHITE}/N_002.png", "N"),
|
| 70 |
+
(f"{TRAIN_DIR_WHITE}/P_001.png", "P"),
|
| 71 |
+
(f"{TRAIN_DIR_WHITE}/P_002.png", "P"),
|
| 72 |
+
(f"{TRAIN_DIR_WHITE}/Q_001.png", "Q"),
|
| 73 |
+
(f"{TRAIN_DIR_WHITE}/Q_002.png", "Q"),
|
| 74 |
+
(f"{TRAIN_DIR_WHITE}/R_001.png", "R"),
|
| 75 |
+
(f"{TRAIN_DIR_WHITE}/R_002.png", "R"),
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
TEST_DATA = TRAIN_DATA
|
| 79 |
+
|
| 80 |
+
# https://docs.pytorch.org/docs/stable/nn.html
|
| 81 |
+
# https://docs.pytorch.org/docs/stable/optim.html
|
| 82 |
+
# https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module
|
| 83 |
+
class CNNModel(nn.Module):
|
| 84 |
+
|
| 85 |
+
def __init__(self, _name):
|
| 86 |
+
super(CNNModel, self).__init__()
|
| 87 |
+
self.name = _name
|
| 88 |
+
print("***KS*** Model: Creating layers")
|
| 89 |
+
# First Convolutional Layer: 32 features, 5x5 kernel
|
| 90 |
+
# https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
|
| 91 |
+
# https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
| 92 |
+
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
|
| 93 |
+
# Second Convolutional Layer: 64 features, 5x5 kernel
|
| 94 |
+
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
|
| 95 |
+
# Fully connected layer
|
| 96 |
+
# https://docs.pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear
|
| 97 |
+
# 64 because last convolution had 64 channels
|
| 98 |
+
# 8 x 8 because 2 pool2d calculations will reduce 32 x 32 --> 16 x 16 --> 8 x 8
|
| 99 |
+
self.fc1 = nn.Linear(8 * 8 * 64, 1024)
|
| 100 |
+
self.dropout = nn.Dropout(p=0.5) # Changed from 0.3 to 0.5
|
| 101 |
+
# Output layer
|
| 102 |
+
self.fc2 = nn.Linear(1024, 13)
|
| 103 |
+
|
| 104 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 105 |
+
|
| 106 |
+
# Initialize weights and biases
|
| 107 |
+
self._initialize_weights()
|
| 108 |
+
|
| 109 |
+
def _initialize_weights(self):
|
| 110 |
+
# Load the pre-trained model
|
| 111 |
+
model_name = f"saved_models/{self.name}.pth"
|
| 112 |
+
print(f"***KS*** Checking pre-trained model: '{model_name}'")
|
| 113 |
+
if os.path.exists(model_name):
|
| 114 |
+
print(f"***KS*** Model '{model_name}' exists, loading weights ...")
|
| 115 |
+
self.load_state_dict(torch.load(model_name, map_location=self.device))
|
| 116 |
+
print("*** KS *** Model loaded.")
|
| 117 |
+
else:
|
| 118 |
+
print(f"*** KS *** Model file '{model_name}' not found. Initializing weights with random values")
|
| 119 |
+
# Initialize weights with truncated normal (approximate with normal and clamp)
|
| 120 |
+
nn.init.trunc_normal_(self.conv1.weight, std=0.1)
|
| 121 |
+
nn.init.constant_(self.conv1.bias, 0.1)
|
| 122 |
+
|
| 123 |
+
nn.init.trunc_normal_(self.conv2.weight, std=0.1)
|
| 124 |
+
nn.init.constant_(self.conv2.bias, 0.1)
|
| 125 |
+
|
| 126 |
+
nn.init.trunc_normal_(self.fc1.weight, std=0.1)
|
| 127 |
+
nn.init.constant_(self.fc1.bias, 0.1)
|
| 128 |
+
|
| 129 |
+
nn.init.trunc_normal_(self.fc2.weight, std=0.1)
|
| 130 |
+
nn.init.constant_(self.fc2.bias, 0.1)
|
| 131 |
+
|
| 132 |
+
self.to(self.device)
|
| 133 |
+
|
| 134 |
+
def save_weights(self):
|
| 135 |
+
print(f"***KS*** Saving model ...")
|
| 136 |
+
# Save the model checkpoint
|
| 137 |
+
os.makedirs('saved_models', exist_ok=True)
|
| 138 |
+
model_save_path = f"saved_models/{self.name}.pth"
|
| 139 |
+
torch.save(self.state_dict(), model_save_path)
|
| 140 |
+
print(f'*** KS *** Model saved in file: {model_save_path}')
|
| 141 |
+
|
| 142 |
+
# Define the computation performed at every call.
|
| 143 |
+
# Should be overridden by all subclasses.
|
| 144 |
+
def forward(self, x):
    """Run the CNN on a batch of 32x32 single-channel tiles.

    Pipeline: conv1 -> ReLU -> maxpool -> conv2 -> ReLU -> maxpool ->
    flatten -> fc1 -> ReLU -> dropout -> fc2 (raw logits, 13 classes).
    The per-step prints are debug instrumentation left in on purpose.
    Shape comments below assume a batch of 26 images.
    """
    print("***KS*** Model: Executing forward calculations")
    # Apply first convolutional layer + ReLU activation

    print(f"***KS*** [0] {x.shape}")
    # [26, 1, 32, 32]
    x = self.conv1(x)
    print(f"***KS*** [1] {x.shape}")
    # [26, 32, 32, 32] 26 - number of images, first 32 number of convolutions
    # --> 32 channels
    # --> each channel is [x,x] size
    # https://docs.pytorch.org/docs/stable/generated/torch.nn.ReLU.html#torch.nn.ReLU
    x = F.relu(x)
    print(f"***KS*** [2] {x.shape}")
    # [26, 32, 32, 32]
    # https://docs.pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
    x = F.max_pool2d(x, kernel_size=2, stride=2)  # First pooling
    print(f"***KS*** [3] {x.shape}")
    # [26, 32, 16, 16]
    # Apply second convolutional layer + ReLU activation
    x = F.relu(self.conv2(x))
    print(f"***KS*** [4] {x.shape}")
    # [26, 64, 16, 16]
    x = F.max_pool2d(x, kernel_size=2, stride=2)  # Second pooling
    # --> 32 channels
    # --> each channel is [x/2 , x/2]
    print(f"***KS*** [5] {x.shape}")
    # [26, 64, 8, 8]

    # Flatten the tensor
    # https://docs.pytorch.org/docs/stable/tensor_view.html
    # https://docs.pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
    # NOTE: hard-codes the 8*8*64 post-pooling size, so input must be 32x32.
    x = x.view(-1, 8 * 8 * 64)
    print(f"***KS*** [6] {x.shape}")
    # [26, 4096]
    # --> first dimension inferred from existing dimensions and from the second dimension below
    # --> second dimensions 8*8*64 = 4096

    # Fully connected layer + ReLU activation
    x = self.fc1(x)
    print(f"***KS*** [7] {x.shape}")
    # [26, 1024]
    # input [?, 4096]
    # output [?, 1024]
    x = F.relu(x)
    print(f"***KS*** [8] {x.shape}")
    # [26, 1024]

    # Apply dropout
    x = self.dropout(x)
    print(f"***KS*** [9] {x.shape}")
    # [26, 1024]

    # Output layer (no activation, as CrossEntropyLoss applies Softmax internally)
    x = self.fc2(x)
    print(f"***KS*** [10] {x.shape}")
    # [26, 13]
    # input [?, 1024]
    # output [?, 13]
    return x
|
| 204 |
+
|
| 205 |
+
def get_device(self):
    """Return the torch.device (cuda or cpu) this model was placed on."""
    return self.device
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Dataset class for PyTorch
|
| 210 |
+
# https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.Dataset
|
| 211 |
+
# Dataset class for PyTorch
# https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.Dataset
class ChessDataset(Dataset):
    """Dataset of 32x32 grayscale chess-tile images with piece labels.

    Input is a sequence of (image_file_path, label_char) pairs. Labels are
    mapped to integer class indices via CHESS_PIECES, where '1' denotes an
    empty square, upper case are white pieces and lower case black pieces.
    """

    CHESS_PIECES = '1KQRBNPkqrbnp'

    def __init__(self, image_train_date):
        self.num_images = len(image_train_date)
        # Each tile is a 32x32 grayscale image
        self.images = np.zeros([self.num_images, 32, 32], dtype=np.uint8)
        self.labels = np.zeros([self.num_images], dtype=np.int64)  # Store labels as integers

        for i, image_file_path_and_label in enumerate(image_train_date):
            # Load Image
            with Image.open(image_file_path_and_label[0]) as img:
                img = img.convert('L')  # Ensure image is in grayscale
                self.images[i, :, :] = np.array(img, dtype=np.uint8)

            self.labels[i] = self.__get_piece_index_from_label__(image_file_path_and_label[1])

        print("***KS*** Done loading training data")

    def __get_piece_index_from_label__(self, label) -> int:
        """Map a piece character to its class index.

        Raises ValueError for an unknown label: str.find() would otherwise
        silently return -1 and poison the training labels.
        """
        index = self.CHESS_PIECES.find(label)
        if index < 0:
            raise ValueError(f"Unknown piece label: {label!r}")
        return index

    def get_piece_label(self, idx) -> str:
        """Inverse mapping: class index -> piece character."""
        return self.CHESS_PIECES[idx]

    def __len__(self):
        return self.num_images

    # required to be implemented
    # returns an item for given key
    def __getitem__(self, idx):
        image = self.images[idx].astype('float32') / 255.0  # Normalize
        image = np.expand_dims(image, axis=0)  # Add channel dimension
        label = self.labels[idx]
        return torch.tensor(image, dtype=torch.float32), label
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class ChessImagesDataset(Dataset):
    """Dataset wrapping in-memory 32x32 grayscale tile arrays for inference.

    Items carry an empty-string label because no ground truth exists at
    prediction time.
    """

    # Class index -> piece character ('1' = empty square). BUG FIX: this
    # constant was missing, so get_piece_label raised AttributeError; it must
    # match the encoding used by ChessDataset during training.
    CHESS_PIECES = '1KQRBNPkqrbnp'

    def __init__(self, images):
        self.num_images = len(images)
        self.images = images

    def __len__(self):
        return self.num_images

    def get_piece_label(self, idx) -> str:
        """Map a predicted class index back to its piece character."""
        return self.CHESS_PIECES[idx]

    # required to be implemented
    # returns an item for given key
    def __getitem__(self, idx):
        image = self.images[idx].astype('float32') / 255.0  # Normalize
        image = np.expand_dims(image, axis=0)  # Add channel dimension
        label = ""  # not needed
        return torch.tensor(image, dtype=torch.float32), label
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class ChessPiecesRecognition:
    """Trains, evaluates and applies the CNN that classifies board tiles.

    Relies on the module-level TRAIN_DATA / TEST_DATA collections of
    (image_path, label) pairs and on CNNModel for the network itself.
    """

    def __init__(self):
        print(f"***KS*** Chess pieces recognition initialized")
        self.model = CNNModel("test-1")
        self.__load_train_data__()

    def __load_train_data__(self):
        """Build the train/test DataLoaders from TRAIN_DATA and TEST_DATA."""
        print(f"*** KS *** loading training data")
        # Load training dataset
        # Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
        print(f"Loading {len(TRAIN_DATA)} Training tiles", end='')
        train_dataset = ChessDataset(TRAIN_DATA)

        # Load testing dataset
        print(f"\n*** KS *** Loading {len(TEST_DATA)} Testing tiles", end='')
        test_dataset = ChessDataset(TEST_DATA)
        print()

        batch_size = 64  # @param {type:"number"}
        # https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size)

    def train(self):
        """Run the full training loop, then save the resulting weights."""
        print(f"***KS*** Training chess pieces recognition")

        # Define loss function and optimizer
        # https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
        criterion = nn.CrossEntropyLoss()  # For multi-class classification

        # https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
        optimizer = optim.Adam(self.model.parameters(), lr=1e-4)

        # Set training parameters
        do_training = True  # Set to True to train the model
        epochs = 100  # @param {type:"number"}

        if do_training:
            # Training loop (model was already moved to its device in __init__)
            self.model.train()
            print(f"*** KS *** Starting training for {epochs} epochs...")
            for epoch in range(epochs):
                running_loss = 0.0
                print(f"***KS*** Epoch: {epoch}")
                for i, (inputs, labels) in enumerate(self.train_loader):
                    # Move inputs and labels to device
                    inputs = inputs.to(self.model.get_device())
                    labels = labels.to(self.model.get_device())

                    # Zero the parameter gradients
                    optimizer.zero_grad()

                    # Forward pass
                    outputs = self.model(inputs)
                    loss = criterion(outputs, labels)

                    # Backward pass and optimize
                    loss.backward()
                    optimizer.step()

                    # Print statistics
                    running_loss += loss.item()
                    if (i + 1) % 10 == 0:  # Print every 10 batches
                        print(f'*** KS *** Epoch [{epoch +1}/{epochs}], Step [{i +1}/{len(self.train_loader)}], '
                              f'Loss: {running_loss / 10:.4f}')
                        running_loss = 0.0

            print('Finished Training')

            self.model.save_weights()

    def eval(self):
        """Report classification accuracy on the held-out test set."""
        # Evaluate the model on the testing dataset
        self.model.eval()  # Set model to evaluation mode
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                # Move inputs and labels to device
                inputs = inputs.to(self.model.get_device())
                labels = labels.to(self.model.get_device())

                outputs = self.model(inputs)
                print(f"***KS*** Got model outputs: \nshape: {outputs.shape}\n{outputs}")

                labels_detected = np.argmax(outputs.cpu(), axis=1)
                print(f"***KS*** Got labels idx detected: \nshape: {labels_detected.shape}\n{labels_detected}")

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracy = correct / total
        print(f'Accuracy on test set: {test_accuracy * 100:.2f}%\n')

    def classify_pieces(self, images):
        """Classify a sequence of tile images; return one piece char per tile.

        The returned string has the same length and order as `images`.
        """
        dataset = ChessImagesDataset(images)
        loader = DataLoader(dataset, batch_size=64)

        # Evaluate the model on the testing dataset
        labels_str = ""
        self.model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            for inputs, labels in loader:
                # Move inputs and labels to device
                inputs = inputs.to(self.model.get_device())
                labels = labels.to(self.model.get_device())

                outputs = self.model(inputs)
                print(f"***KS*** Got model outputs: \nshape: {outputs.shape}\n{outputs}")

                labels_detected = np.argmax(outputs.cpu(), axis=1)
                print(f"***KS*** Got labels idx detected: \nshape: {labels_detected.shape}\n{labels_detected}")

                # BUG FIX: accumulate predictions across batches. The original
                # overwrote labels_str each iteration, so only the final
                # batch's labels survived when more than 64 tiles were given.
                batch_labels = [dataset.get_piece_label(ix) for ix in labels_detected]
                labels_str += ''.join(batch_labels)

        return labels_str
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# Manual smoke test — uncomment to train/evaluate locally:
#t = ChessPiecesRecognition()
#t.train()
#t.eval()

# This file is meant to be imported as a module, not executed directly.
if __name__ == "__main__":
    print("This is a module and should not be executed directly")
|
install-requirements.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Activate the project's virtual environment in the current shell,
# then install all Python dependencies pinned in requirements.txt.
. ./venv/bin/activate
pip install -r ./requirements.txt
|
my_prompt_config.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from smolagents import PromptTemplates, PlanningPromptTemplate, FinalAnswerPromptTemplate, ManagedAgentPromptTemplate
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class PromptConfig:
    """Container for the smolagents prompt-template bundle used by the agent.

    NOTE(review): planning and managed_agent templates are intentionally
    (near-)empty; only system_prompt and final_answer carry real content.
    The template strings below are runtime behavior — do not reword casually.
    """

    PROMPT_TEMPLATES = PromptTemplates(
        system_prompt="""
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish
your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
Describe your initial plan as a set of bullet points.
Each bullet point should describe in one sentence an action which is to be taken in this step.
Use the tools provided. If you are going to use a tool, describe in detail why you are going
to use that particular tool and explain parameters used to invoke the tool.
Analyze the question provided.
Describe each step which needs to be taken to answer it.
""",
        planning=PlanningPromptTemplate(
            initial_plan="""

""",
            update_plan_pre_messages="""

""",
            update_plan_post_messages="""

""",
        ),
        managed_agent=ManagedAgentPromptTemplate(task="", report=""),
        final_answer=FinalAnswerPromptTemplate(
            pre_messages="",
            post_messages="""
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of
numbers and/or strings.
If you are asked for a number, don’t use comma to write your number neither use units such as $ or percent
sign unless specified otherwise.
If you are asked for a string, don’t use articles, neither abbreviations (e.g. for cities), and write the digits in
plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put
in the list is a number or a string
"""
        ),
    )

    def __init__(self):
        # Instantiation only announces readiness; templates are class-level.
        print("Prompt Templates initialized")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
#EMPTY_PROMPT_TEMPLATES = PromptTemplates(
|
| 49 |
+
# system_prompt="",
|
| 50 |
+
# planning=PlanningPromptTemplate(
|
| 51 |
+
# initial_plan="",
|
| 52 |
+
# update_plan_pre_messages="",
|
| 53 |
+
# update_plan_post_messages="",
|
| 54 |
+
# ),
|
| 55 |
+
# managed_agent=ManagedAgentPromptTemplate(task="", report=""),
|
| 56 |
+
# final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
|
| 57 |
+
#)
|
my_tools.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from smolagents import Tool
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import requests
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
|
| 6 |
+
#AUTHORIZED_TYPES = [
|
| 7 |
+
# "string",
|
| 8 |
+
# "boolean",
|
| 9 |
+
# "integer",
|
| 10 |
+
# "number",
|
| 11 |
+
# "image",
|
| 12 |
+
# "audio",
|
| 13 |
+
# "array",
|
| 14 |
+
# "object",
|
| 15 |
+
# "any",
|
| 16 |
+
# "null",
|
| 17 |
+
#]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ReverseStringTool(Tool):
    """smolagents tool that decodes a character-reversed string."""

    name = "_my_reverse_string"
    description = """
Decode a string which is provided in a reversed form.
"""

    inputs = {
        "_inp": {
            "type": "string",
            "description": "encoded input string",
        }
    }

    output_type = "string"

    def forward(self, _inp: str) -> str:
        """Return *_inp* with its characters in reverse order."""
        # Slicing with a -1 step is equivalent to prepending each character
        # in turn, but runs at C speed.
        return _inp[::-1]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ImageLoadTool(Tool):
    """smolagents tool that downloads the attachment image for a task id."""

    name = "_my_image_load"
    description = """
Load image for the provided task id
"""

    inputs = {
        "task_id": {
            "type": "string",
            "description": "task id to load image",
        }
    }

    output_type = "image"
    api_url = "https://agents-course-unit4-scoring.hf.space"

    def forward(self, task_id: str) -> Image.Image:
        """Fetch `<api_url>/files/<task_id>` and return it as an RGB PIL image.

        Raises requests.HTTPError on a non-2xx response and
        PIL.UnidentifiedImageError if the payload is not an image.
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"
        }
        url = f"{self.api_url}/files/{task_id}"
        # timeout: never let the agent hang forever on a dead endpoint.
        # raise_for_status: surface 404/5xx instead of handing error bytes
        # to PIL, which would fail with a confusing decode error.
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
        print(f"***KS*** Loaded image for \n\ttask id: {task_id} \n\timage: {image}")
        return image
|
requirements.txt
CHANGED
|
@@ -1,2 +1,10 @@
|
|
| 1 |
gradio
|
| 2 |
-
requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
gradio
|
| 2 |
+
requests
|
| 3 |
+
smolagents
|
| 4 |
+
gradio[oauth]
|
| 5 |
+
pytest
|
| 6 |
+
matplotlib
|
| 7 |
+
PyQt6
|
| 8 |
+
chess
|
| 9 |
+
opencv-python
|
| 10 |
+
torch
|
run.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Activate the local virtual environment, then launch the Gradio app.
. ./venv/bin/activate
python3 app.py
|
simple.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import requests
|
| 4 |
+
import inspect
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from smolagents import CodeAgent, tool, InferenceClientModel, WebSearchTool, load_tool, PromptTemplates, Tool, FinalAnswerTool
|
| 7 |
+
from smolagents import PromptTemplates, PlanningPromptTemplate, FinalAnswerPromptTemplate, ManagedAgentPromptTemplate
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from my_tools import ReverseStringTool, ImageLoadTool
|
| 11 |
+
from chess_board_tool import ChessBoard
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Task whose attachment is the chess-board image used for this experiment.
task_id = "cca530fc-4052-43b2-b130-b30968d8aa44"

# Related chessboard-recognition project used as a reference:
# https://github.com/kratos606/chessboard-recogniser/tree/main

# Hugging Face model id that drives the CodeAgent's reasoning.
MODEL_REASONING = "Qwen/Qwen2.5-Coder-32B-Instruct"
#MODEL_REASONING = "Qwen/Qwen2.5-72B-Instruct" not good
#"meta-llama/Meta-Llama-3-70B-Instruct"
# jayasuryajsk/chess-reasoner-qwen
# https://huggingface.co/jayasuryajsk/chess-reasoner-qwen
|
| 24 |
+
|
| 25 |
+
# Prompt bundle handed to the CodeAgent below. The string contents are part
# of runtime behavior — they shape the model's answers — so treat them as
# configuration, not prose.
PROMPT_TEMPLATES = PromptTemplates(
    system_prompt="""
You are a general AI assistant.

Answer the following questions as best you can.

Describe your initial plan as a set of bullet points.
Each bullet point should describe in one sentence an action which is to be taken in this step.

Use the tools provided. If you are going to use a tool, describe in detail how you are going
to use that particular tool and explain parameters used to invoke the tool.

Tools provided: final_answer , _my_reverse_string , _my_image_load, _my_chess_board

YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of
numbers and/or strings.
If you are asked for a number, don’t use comma to write your number neither use units such as $ or percent
sign unless specified otherwise.
If you are asked for a string, don’t use articles, neither abbreviations (e.g. for cities), and write the digits in
plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put
in the list is a number or a string.

Report your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].

""",

    # Planning/managed-agent templates intentionally left blank.
    planning=PlanningPromptTemplate(
        initial_plan="""

""",
        update_plan_pre_messages="""

""",
        update_plan_post_messages="""

""",
    ),
    managed_agent=ManagedAgentPromptTemplate(task="", report=""),
    final_answer=FinalAnswerPromptTemplate(
        pre_messages="",
        post_messages="""

"""
    ),
)
|
| 71 |
+
|
| 72 |
+
# Alternative prompts kept for experimentation:
#question = f"Load an image for task id {task_id} and describe the chess position shown on the image. "
#question = f"Load an image for task id {task_id} and display it using matplotlib "
question = f"Load an image for task id {task_id} and analyze the chess board "

# CodeAgent wired with the custom tools defined in my_tools.py and
# chess_board_tool.py; uses the prompt bundle declared above.
reasoning_agent = CodeAgent(
    name="CourseAssistant",
    description="General AI Assistant",
    tools=[ImageLoadTool(), FinalAnswerTool(), ReverseStringTool(), ChessBoard()],
    model=InferenceClientModel(model_id=MODEL_REASONING),
    planning_interval=3,  # This is where you activate planning!,
    prompt_templates=PROMPT_TEMPLATES,
    #managed_agents=[web_search_agent],
    # Modules the agent-generated code is allowed to import at runtime.
    additional_authorized_imports=["PIL","chess","my_tools","matplotlib","matplotlib.pyplot","chess_board_tool"],
)

reasoning_agent.run(question)
|
simple.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Activate the virtualenv, clear the terminal, then run the standalone
# agent experiment script.
. ./venv/bin/activate
clear
python3 simple.py
|
test_tools.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from my_tools import ReverseStringTool, ImageLoadTool
|
| 2 |
+
from chess_board_tool import ChessBoard
|
| 3 |
+
import pytest
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib as mp
|
| 6 |
+
|
| 7 |
+
#pytest --capture=no
|
| 8 |
+
|
| 9 |
+
@pytest.mark.parametrize("_inp,_exp",[("abc", "cba"),("ihg fed cba", "abc def ghi")])
def test_tool_reverse_string(_inp,_exp):
    """The reverse tool must exactly undo a reversed input string."""
    decoded = ReverseStringTool().forward(_inp)
    assert decoded == _exp
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@pytest.mark.parametrize("_task_id,_exp",[("cca530fc-4052-43b2-b130-b30968d8aa44", "")])
def test_tool_image_load(_task_id,_exp):
    # NOTE(review): smoke test only — it downloads from the live scoring API
    # (needs network access) and never asserts on the result.
    #assert ReverseStringTool().forward(_inp) == _exp
    print(f"Loading image for task id: {_task_id}")
    t = ImageLoadTool()
    result = t.forward(_task_id)
    print(f"Got result: {result}")
    # Select the Qt backend for the (currently commented-out) preview below.
    mp.use('QtAgg')
    #plt.imshow(result)
    #plt.show()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@pytest.mark.parametrize("_task_id,_exp",[("cca530fc-4052-43b2-b130-b30968d8aa44", "")])
def test_tool_chess_board(_task_id,_exp):
    # NOTE(review): smoke test only — downloads the board image (needs
    # network access) and runs FEN extraction, but never asserts on `fen`.
    #assert ReverseStringTool().forward(_inp) == _exp
    print(f"Loading image for task id: {_task_id}")
    t = ImageLoadTool()
    image = t.forward(_task_id)
    print(f"Got result: {image}")
    board_tool = ChessBoard()
    fen = board_tool.forward(image)
|