Falguni committed on
Commit
e048cbb
·
1 Parent(s): 80a5cdf

Add annotation support in natural language with LLMs

Browse files
src/llm/__init__.py ADDED
File without changes
src/llm/base_llm_wrapper.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List
3
+ from ThinkSquare.src.llm.data_models.typeddict_data_models import MultiCommentModel
4
+ from src.llm.prompts.prompt_head import (
5
+ prompt_head_expert,
6
+ prompt_head_jarvis,
7
+ prompt_head_novice,
8
+ prompt_head_natural,
9
+ prompt_head_generic,
10
+ )
11
+ from src.llm.prompts.prompt_core import prompt_core
12
+
13
+
14
class BaseLLMWrapper(ABC):
    """Abstract base class for LLM wrappers that rewrite chess annotations.

    Subclasses implement ``generate_response`` for a concrete LLM backend;
    ``comment`` assembles the prompt (character head + shared core) and
    delegates generation to the subclass.
    """

    def comment(
        self,
        character: str,
        game,
        comment_refs: List[int],
        move_nums: List,
        played_moves: List,
        played_by: List,
        comments: List,
        move_suggestions: List,
        pre_eval_scores: List,
        post_eval_scores: List,
    ) -> MultiCommentModel:
        """
        Rewrite the comments in a specific tone or style.

        Args:
            character (str): The character or style in which to rewrite the
                comments ("natural", "jarvis", "novice", "expert", or any
                free-form role handled by the generic prompt head).
            game: The game; rendered into the prompt via ``str.format``.
            comment_refs (List[int]): Reference ids, one per comment, echoed
                back by the LLM so rewritten comments can be matched up.
            move_nums (List): Move numbers corresponding to the comments.
            played_moves (List): The moves actually played (SAN).
            played_by (List): Which side played each move.
            comments (List): Comments to be rewritten.
            move_suggestions (List): Suggested better variations (may be None).
            pre_eval_scores (List): Engine evaluation (centipawns) before each move.
            post_eval_scores (List): Engine evaluation (centipawns) after each move.

        Returns:
            MultiCommentModel: A model containing the rewritten comments.
        """
        character = character.lower().strip()
        # Dispatch table instead of an if/elif ladder; unknown characters
        # fall back to the generic template parameterised by the name.
        known_heads = {
            "natural": prompt_head_natural,
            "jarvis": prompt_head_jarvis,
            "novice": prompt_head_novice,
            "expert": prompt_head_expert,
        }
        prompt_head = known_heads.get(
            character, prompt_head_generic.format(character=character)
        )

        # NOTE(review): zip() silently truncates to the shortest input list;
        # callers are expected to pass equal-length lists for all ten fields.
        comments_info = [
            {
                "comment_ref": comment_ref,
                "move_num": move_num,
                "move": played_move,
                "played_by": played_by_player,
                "comment": comment,
                "better_variation": suggestion,
                "score_before_move": pre_eval_score,
                "score_after_move": post_eval_score,
            }
            for (
                comment_ref,
                move_num,
                played_move,
                played_by_player,
                comment,
                suggestion,
                pre_eval_score,
                post_eval_score,
            ) in zip(
                comment_refs,
                move_nums,
                played_moves,
                played_by,
                comments,
                move_suggestions,
                pre_eval_scores,
                post_eval_scores,
            )
        ]

        prompt = (
            prompt_head + "\n" + prompt_core.format(pgn=game, comments=comments_info)
        )

        return self.generate_response(prompt)

    @abstractmethod
    def generate_response(self, prompt: str) -> MultiCommentModel:
        """
        Generate a structured response based on the provided prompt.

        Args:
            prompt (str): The input prompt to generate a response for.

        Returns:
            MultiCommentModel: The generated, structured response.
        """
src/llm/data_models/__init__.py ADDED
File without changes
src/llm/data_models/pydantic_data_models.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
class SingleCommentModel(BaseModel):
    """One rewritten comment, keyed back to its original via comment_ref."""

    # comment_ref is a sequential reference id assigned when comments are
    # collected (see annotate(): comment_refs.append(len(comment_refs) + 1)),
    # NOT the game's move number — the move number travels separately in the
    # prompt payload. The previous description ("Move number in the game")
    # was misleading for structured-output generation.
    comment_ref: int = Field(
        description="Reference id of the original comment being rewritten"
    )
    comment: str = Field(description="Rewritten comment text for the move")
+
9
+
10
class MultiCommentModel(BaseModel):
    """Collection of rewritten comments, one entry per annotated move."""

    comments: List[SingleCommentModel] = Field(
        description="List of rewritten comments for each move (with comment) in the game",
    )


if __name__ == "__main__":
    # Optional sanity check: show the JSON schema the LLM must conform to.
    # Guarded under __main__ so merely importing this module no longer
    # performs a mid-file import and prints the schema to stdout.
    from pydantic import TypeAdapter

    adapter = TypeAdapter(MultiCommentModel)
    schema = adapter.json_schema()
    print(schema)
src/llm/data_models/typeddict_data_models.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, TypedDict
2
+
3
+
4
class SingleCommentModel(TypedDict):
    """Model representing a single comment for a move in a game."""

    # Sequential reference id linking the rewritten comment back to the
    # original one (assigned in annotate(), not the game's move number).
    comment_ref: int
    # The comment text for the move.
    comment: str


class MultiCommentModel(TypedDict):
    """Model representing multiple comments for moves in a game."""

    comments: List[SingleCommentModel]
src/llm/prompts/__init__.py ADDED
File without changes
src/llm/prompts/prompt_core.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Core prompt body shared by every character. A character-specific prompt
# head (see prompt_head.py) is prepended before the prompt is sent to the
# LLM. Placeholders: {pgn} — the game in PGN; {comments} — the list of
# comment dicts built in BaseLLMWrapper.comment().
prompt_core = """You will be provided with a chess game in PGN format, along with detailed comments.
Task:
1. Your task is to rewrite the comments based on your role. You need to maintain the original meaning of the comments while adapting them to fit your role.
2. Only rewrite the comments, do not change the move numbers or the moves themselves.
3. Do not make up comments or variations, only rewrite the provided ones.
4. The player who played the move is also provided, use this information to adapt the comments accordingly.
5. In your comments, as necessary - especially after a sharp change - indicate the overall better/worse position after the move from the engine evaluation score after the move. Do not use the engine evaluation score directly; rather interpret it in a human-readable way, e.g. "White is better" or "Black has a slight advantage".
Do not bother if the score is not significant, e.g. less than 50 centipawns.
6. Your output must be a valid json object.
7. Your output must only contain the comment_refs and the rewritten comments, without any additional text or explanations. The additional information is provided for you to understand the context of the game and the comments.

Game in PGN format:
{pgn}

List of comments below. Each comment contains the comment_ref, the move number, the actual move, the player who played the move, a comment, a suggested better variation (can be None), engine evaluated score before the move (in centipawns), engine evaluated score after the move (in centipawns).
Remember, the engine evaluated scores are from the point of view of the white player. Positive scores indicate a favorable position for white, negative scores indicate a favorable position for black.
{comments}


"""
src/llm/prompts/prompt_head.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Character-specific prompt heads. Exactly one of these is prepended to
# prompt_core (see BaseLLMWrapper.comment) before the prompt is sent to
# the LLM; prompt_head_generic covers any character not listed here.

# Serious, expert-level commentary.
prompt_head_expert = """
You are an expert chess commentator.
"""
# Jarvis persona (Iron Man's AI) — formal, robotic tone.
prompt_head_jarvis = """
You are Jarvis (from IronMan). You will act as a chess commentator. Maintain the robotic Jarvis tone.
"""

# Enthusiastic beginner's perspective.
prompt_head_novice = """You are a novice chess player. You will act as a chess commentator. As a novice player, you are enthusiastic and eager to learn, but you may not fully understand advanced strategies or tactics.
"""

# Plain, jargon-free commentary (the default "natural" style).
prompt_head_natural = """You are a natural chess commentator. You will act as a chess commentator. As a natural commentator, you provide clear and straightforward explanations without using complex jargon or advanced strategies.
"""

# Fallback for arbitrary characters; {character} is filled in at call time.
prompt_head_generic = """Your role: {character}. You will act as a chess commentator. Maintain the {character} tone."""
src/llm/sambanova_wrapper.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ThinkSquare.src.llm.base_llm_wrapper import BaseLLMWrapper
2
+ import os
3
+ from langchain_sambanova import ChatSambaNovaCloud
4
+
5
+ from ThinkSquare.src.llm.data_models.typeddict_data_models import (
6
+ MultiCommentModel,
7
+ )
8
+
9
+
10
class SambaNovaWrapper(BaseLLMWrapper):
    """LLM wrapper backed by SambaNova Cloud via langchain."""

    # Model configuration as class attributes so subclasses/tests can
    # override them without touching generate_response.
    _MODEL = "Meta-Llama-3.3-70B-Instruct"
    _MAX_TOKENS = 8192

    def generate_response(self, prompt: str) -> MultiCommentModel:
        """Send the prompt to SambaNova Cloud and return the structured reply.

        Args:
            prompt (str): The fully assembled prompt (head + core).

        Returns:
            MultiCommentModel: Structured comments parsed from the reply.

        Raises:
            RuntimeError: If the SAMBANOVA_API_KEY environment variable is
                not set. Previously a None key was passed through silently,
                producing a confusing downstream failure.
        """
        api_key = os.environ.get("SAMBANOVA_API_KEY")
        if not api_key:
            raise RuntimeError("SAMBANOVA_API_KEY environment variable is not set")

        llm = ChatSambaNovaCloud(
            api_key=api_key,
            streaming=False,
            model=self._MODEL,
            max_tokens=self._MAX_TOKENS,
        )

        # Constrain the model output to the MultiCommentModel schema.
        llm = llm.with_structured_output(MultiCommentModel)

        return llm.invoke(prompt)
src/thinksqure_engine.py CHANGED
@@ -5,11 +5,14 @@ import chess.engine
5
  import chess.svg
6
  import chess.pgn
7
 
8
- from src.util.pgn_util import add_variation
 
 
9
 
10
 
11
  class ThinkSquareEngine:
12
  _ENGINE = str(Path("bin/stockfish").resolve())
 
13
 
14
  @staticmethod
15
  def get_best_move(fen: Optional[str] = None, time_limit=0.1):
@@ -58,26 +61,62 @@ class ThinkSquareEngine:
58
  elif eval_drop > 50:
59
  label = "Inaccuracy"
60
  elif eval_drop < -150:
61
- label = "Super Brilliant"
62
- elif eval_drop < -60:
63
  label = "Brilliant"
 
 
64
  elif abs(eval_drop) <= 30:
65
  label = None
66
  else:
67
- label = "Good"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  if label is not None:
70
  comment = f"{label}. "
 
 
 
 
71
  if eval_drop > 0 and engine_best_move_san is not None:
72
  comment += f"Better was {engine_best_move_san} "
73
- played_node.comment = comment
 
74
  if pv is not None:
75
- add_variation(played_node.parent, pv)
 
 
 
 
 
 
 
 
 
76
  else:
77
- played_node.comment = comment
 
 
78
 
79
  @staticmethod
80
- def annotate(game, analysis_time: float = 0.1):
81
 
82
  if not isinstance(game, chess.pgn.Game):
83
  raise ValueError("Input must be a chess.pgn.Game object")
@@ -90,6 +129,17 @@ class ThinkSquareEngine:
90
 
91
  node = game
92
 
 
 
 
 
 
 
 
 
 
 
 
93
  while node.variations:
94
  board = node.board()
95
  played_node = node.variation(0)
@@ -114,17 +164,73 @@ class ThinkSquareEngine:
114
  board.push(played_move)
115
 
116
  if played_move_san != engine_best_move_san:
117
- ThinkSquareEngine._perform_post_analysis_and_add_comment(
118
- analysis_time,
119
- board,
120
- played_node,
121
- pre_eval,
122
- engine_best_move_san,
123
- pv,
 
 
124
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  node = played_node
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  return game
129
 
130
  @staticmethod
@@ -179,7 +285,7 @@ class ThinkSquareEngine:
179
  fen = chess.STARTING_FEN
180
 
181
  board = chess.Board(fen)
182
- svg = chess.svg.board(board=board)
183
  return svg
184
 
185
  @staticmethod
 
5
  import chess.svg
6
  import chess.pgn
7
 
8
+ from ThinkSquare.src.llm.data_models.typeddict_data_models import MultiCommentModel
9
+ from src.llm.sambanova_wrapper import SambaNovaWrapper
10
+ from src.util.pgn_util import add_variation, format_pv
11
 
12
 
13
  class ThinkSquareEngine:
14
  _ENGINE = str(Path("bin/stockfish").resolve())
15
+ llm_commentator = SambaNovaWrapper()
16
 
17
  @staticmethod
18
  def get_best_move(fen: Optional[str] = None, time_limit=0.1):
 
61
  elif eval_drop > 50:
62
  label = "Inaccuracy"
63
  elif eval_drop < -150:
 
 
64
  label = "Brilliant"
65
+ elif eval_drop < -60:
66
+ label = "Very Good"
67
  elif abs(eval_drop) <= 30:
68
  label = None
69
  else:
70
+ label = None # "Good"
71
+
72
+ if post_eval is not None:
73
+ if post_eval > 200:
74
+ overall_situation = "White is better"
75
+ elif post_eval > 100:
76
+ overall_situation = "White has a slight advantage"
77
+ elif post_eval < -200:
78
+ overall_situation = "Black is better"
79
+ elif post_eval < -100:
80
+ overall_situation = "Black has a slight advantage"
81
+ else:
82
+ overall_situation = "No side has a significant advantage"
83
+ else:
84
+ overall_situation = None
85
+
86
+ node_reference = None
87
+ _comment = None
88
+ variation = None
89
+ variation_san = None
90
+ post_eval_score = post_eval
91
 
92
  if label is not None:
93
  comment = f"{label}. "
94
+
95
+ node_reference = played_node
96
+ _comment = comment
97
+
98
  if eval_drop > 0 and engine_best_move_san is not None:
99
  comment += f"Better was {engine_best_move_san} "
100
+ _comment = comment
101
+ # played_node.comment = comment
102
  if pv is not None:
103
+ # add_variation(played_node.parent, pv)
104
+ variation = pv
105
+ variation_san = format_pv(pv, played_node.parent.board())
106
+ else:
107
+ # played_node.comment = comment
108
+ _comment = comment
109
+
110
+ if overall_situation is not None:
111
+ if _comment is not None:
112
+ _comment += f"\n Overall, {overall_situation}."
113
  else:
114
+ _comment = f"Overall, {overall_situation}."
115
+
116
+ return node_reference, _comment, variation, variation_san, post_eval_score
117
 
118
  @staticmethod
119
+ def annotate(game, analysis_time: float = 0.1, llm_character: Optional[str] = None):
120
 
121
  if not isinstance(game, chess.pgn.Game):
122
  raise ValueError("Input must be a chess.pgn.Game object")
 
129
 
130
  node = game
131
 
132
+ comment_refs = []
133
+ node_refs = []
134
+ comments = []
135
+ variations = []
136
+ variation_sans = []
137
+ move_numbers = []
138
+ played_moves = []
139
+ played_by = []
140
+ pre_eval_scores = []
141
+ post_eval_scores = []
142
+
143
  while node.variations:
144
  board = node.board()
145
  played_node = node.variation(0)
 
164
  board.push(played_move)
165
 
166
  if played_move_san != engine_best_move_san:
167
+ node_referece, _comment, variation, variation_san, post_eval_score = (
168
+ ThinkSquareEngine._perform_post_analysis_and_add_comment(
169
+ analysis_time,
170
+ board,
171
+ played_node,
172
+ pre_eval,
173
+ engine_best_move_san,
174
+ pv,
175
+ )
176
  )
177
+ else:
178
+ node_referece = played_node
179
+ _comment = "Best move played."
180
+ variation = None
181
+ variation_san = None
182
+ post_info = ThinkSquareEngine.get_engine_analysis(board, analysis_time)
183
+ post_eval = post_info["score"].white().score(mate_score=100000)
184
+ post_eval_score = post_eval
185
+
186
+ if node_referece is not None:
187
+ node_refs.append(node_referece)
188
+ comments.append(_comment)
189
+ variations.append(variation)
190
+ variation_sans.append(variation_san)
191
+ move_numbers.append(played_node.parent.board().fullmove_number)
192
+ played_moves.append(played_move_san)
193
+ played_by.append(
194
+ "white" if played_node.parent.board().turn else "black"
195
+ )
196
+ pre_eval_scores.append(pre_eval)
197
+ post_eval_scores.append(post_eval_score)
198
+ comment_refs.append(len(comment_refs) + 1)
199
 
200
  node = played_node
201
 
202
+ if llm_character is not None:
203
+ formatted_comments = ThinkSquareEngine.llm_commentator.comment(
204
+ character=llm_character,
205
+ game=str(game),
206
+ comment_refs=comment_refs,
207
+ move_nums=move_numbers,
208
+ comments=comments,
209
+ move_suggestions=variation_sans,
210
+ played_moves=played_moves,
211
+ played_by=played_by,
212
+ pre_eval_scores=pre_eval_scores,
213
+ post_eval_scores=post_eval_scores,
214
+ )
215
+ for comment_with_move_num in formatted_comments["comments"]:
216
+ comment_ref = comment_with_move_num["comment_ref"]
217
+ comment = comment_with_move_num["comment"]
218
+
219
+ if comment_ref not in comment_refs:
220
+ raise ValueError(
221
+ f"Comment reference {comment_ref} not found in comment_refs."
222
+ )
223
+
224
+ index = comment_refs.index(comment_ref)
225
+ comments[index] = comment
226
+
227
+ for node_ref, comment, variation in zip(node_refs, comments, variations):
228
+ if node_ref is None:
229
+ continue
230
+ node_ref.comment = comment
231
+ if variation is not None:
232
+ add_variation(node_ref.parent, variation)
233
+
234
  return game
235
 
236
  @staticmethod
 
285
  fen = chess.STARTING_FEN
286
 
287
  board = chess.Board(fen)
288
+ svg = chess.svg.board(board=board, size=400, coordinates=True)
289
  return svg
290
 
291
  @staticmethod