File size: 3,884 Bytes
0a55bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from dotenv import load_dotenv

load_dotenv()

# Backend toggles. `use_mistral` gates the Mistral chat request further down.
# NOTE(review): `use_gemini` is not read anywhere in the visible code; the
# Gemini request runs whenever the script is executed directly.
use_mistral = False
use_gemini = True

import base64
import requests
import os
from mistralai import Mistral

def encode_image(image_path):
    """Read a file and return its contents as base64 text.

    Args:
        image_path: Location of the file to read in binary mode.

    Returns:
        The base64 representation of the file's bytes as a ``str``, or
        ``None`` when the file is missing or any other error occurs (a
        message is printed in both failure cases).
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except FileNotFoundError:
        # Dedicated message for the most common failure: a wrong path.
        print(f"Error: The file {image_path} was not found.")
    except Exception as err:
        # Best-effort helper: report anything else and fall through to None.
        print(f"Error: {err}")
    return None

# Path to the image that is sent to the vision models (relative to the
# current working directory).
image_path = "chess_test.jpg"

# Base64 text of the image. NOTE(review): encode_image() returns None when the
# file cannot be read, and nothing below checks for that.
base64_image = encode_image(image_path)

# Retrieve the Mistral API key from the environment (None when the variable
# is not set).
api_key = os.environ.get("API_KEY_MISTRAL")

# Mistral model used for the chat request.
model = "pixtral-large-latest"

# Initialize the Mistral client. NOTE(review): this happens even when
# `use_mistral` is False, and the name `client` is rebound to a Gemini client
# later in the file.
client = Mistral(api_key=api_key)

# Chat payload for the Mistral request: a single user turn made of a text
# instruction followed by the image embedded as a base64 data URL.
# NOTE(review): if `base64_image` is None the f-string below yields the
# literal text "None" after "base64,", i.e. an invalid data URL.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": r"""Below is an image of a chess board mid-game. Only use the image as a reference for the response. NEVER use implicit knowledge of chess or positions. 
                The bottom left square is A1, the top right square is H8.
                
                Identify the position of all pieces in JSON format: {colour:{piece_type:[coordinates]}}
                
                Chess board diagram:"""
            },
            {
                "type": "image_url",
                "image_url": f"data:image/jpeg;base64,{base64_image}"
            }
        ]
    }
]

# Mistral request, executed only when the `use_mistral` toggle is enabled.
# It relies on `client` still being the Mistral client at this point; the name
# is rebound further down in the file.
if use_mistral:
    # Get the chat response
    chat_response = client.chat.complete(
        model=model,
        messages=messages
    )

    # Print the content of the first choice in the response
    print(chat_response.choices[0].message.content)
    
#### Gemini

from google import genai
from google.genai import types

# Only run this block for Gemini Developer API
#
# NOTE(review): this rebinds the module-level name `client`, which held the
# Mistral client until here. generate() builds its own local client, so this
# instance is not used anywhere in the visible part of the file.

client = genai.Client(api_key=os.environ.get("API_KEY_GEMINI2"))
# Chooses which Gemini preview model is requested: the "flash" variant when
# True, the "pro" variant otherwise.
flash = True

# Model identifier read by generate().
google_model = (
    'gemini-2.5-flash-preview-05-20' if flash else 'gemini-2.5-pro-preview-05-06'
)

# Prompt sent to Gemini together with the board image. It describes the board
# orientation and asks for four tasks in order, ending with a FEN string
# wrapped in JSON. The text is runtime data: edit it only to change what the
# model is asked to do.
chess_prompt = """Using this image of a chess board diagram. Black squares are coloured dark brown, white squares are light brown. A1 is at the bottom left of the image, H8 is at the top right. Complete the following tasks in order:
Task 1: Count the number of occupied and unoccupied squares in each row. e.g. {'occupied':3, 'unoccupied':5} => STRING
Task 2: Count the number of each piece type in each row. Check that they add up to the total number of pieces. => STRING
Task 3: In JSON format note the position of every piece by colour, type and then list of coordinates => JSON
Task 4: Convert JSON format to FEN string. {'board_fen': <FEN STRING>} => JSON"""


# To run this code you need to install the following dependencies:
# pip install google-genai

import base64
import os
from google import genai
from google.genai import types

def generate():
    """Stream Gemini's answer for the chess-board image to stdout.

    Builds one user turn containing the image and ``chess_prompt``, sends it
    to the model named by ``google_model`` and prints the streamed text as it
    arrives. Reads the module-level ``base64_image``, ``google_model`` and
    ``chess_prompt``.

    The API key is taken from ``GEMINI_API_KEY``; if that variable is unset or
    empty, ``API_KEY_GEMINI2`` (the variable used by the module-level client)
    is tried instead.

    Raises:
        RuntimeError: If ``base64_image`` is ``None``, i.e. the image could
            not be read, so there is nothing to send.
    """
    if base64_image is None:
        # Fail early with a clear message instead of sending an empty part.
        raise RuntimeError(
            "No image data available: encode_image() returned None."
        )

    gemini_key = os.environ.get("GEMINI_API_KEY") or os.environ.get(
        "API_KEY_GEMINI2"
    )
    client = genai.Client(api_key=gemini_key)

    # Part.from_bytes expects the raw image bytes, while base64_image holds
    # base64 *text* (it is also used for the Mistral data URL), so decode it.
    image_bytes = base64.b64decode(base64_image)

    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(
                    mime_type="image/jpeg",
                    data=image_bytes,
                ),
                types.Part.from_text(text=chess_prompt),
            ],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=0.15,
        response_mime_type="text/plain",
    )

    for chunk in client.models.generate_content_stream(
        model=google_model,
        contents=contents,
        config=generate_content_config,
    ):
        # Streamed chunks may carry no text; printing them would emit the
        # literal string "None" into the output.
        piece = chunk.text
        if piece:
            print(piece, end="")

# Script entry point. Only the Gemini request is guarded here; the module-level
# statements above (image encoding, client construction and the optional
# Mistral call) also execute when this file is imported.
if __name__ == "__main__":
    generate()