jinysun committed on
Commit
5e0c30f
·
verified ·
1 Parent(s): 7854f0d

Update tool/ImageAnalysis.py

Browse files
Files changed (1) hide show
  1. tool/ImageAnalysis.py +67 -67
tool/ImageAnalysis.py CHANGED
@@ -1,68 +1,68 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Sat Oct 26 15:35:19 2024
4
-
5
- @author: BM109X32G-10GPU-02
6
- """
7
-
8
- from langchain_community.embeddings import OllamaEmbeddings
9
- from langchain.tools import BaseTool
10
- from langchain_openai import ChatOpenAI
11
- from langchain_core.messages import HumanMessage, SystemMessage
12
- from langchain.base_language import BaseLanguageModel
13
- import base64
14
- from io import BytesIO
15
- from PIL import Image
16
-
17
-
18
-
19
- def convert_to_base64(pil_image):
20
- buffered = BytesIO()
21
- pil_image.save(buffered, format="PNG")
22
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
23
- return img_str
24
-
25
-
26
- class Imageanalysis(BaseTool):
27
- name: str = "Imageanalysis"
28
- description: str = (
29
- "Useful to answer questions according to the image, figure, diagram or graph. "
30
- "Useful to analysis the information in the image, figure, diagram or graph. "
31
- "Input query about image/figure/graph/diagram, return the response"
32
- )
33
- return_direct: bool = True
34
- llm: BaseLanguageModel = None
35
- path : str = None
36
-
37
- def __init__(self, path):
38
- super().__init__( )
39
- self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
40
- base_url="https://www.dmxapi.com/v1")
41
- self.path = path
42
- # api keys
43
-
44
- def _run(self, query ) -> str:
45
- try:
46
- pil_image = Image.open(self.path)
47
- rgb_im = pil_image.convert('RGB')
48
- image_b64 = convert_to_base64(pil_image)
49
- message = HumanMessage(
50
- content=[
51
- {"type": "text", "text": query},
52
- {
53
- "type": "image_url",
54
- "image_url": {"url":f"data:image/jpeg;base64,{image_b64}"},
55
- },
56
- ],)
57
- response = self.llm.invoke([message])
58
- return response.content
59
-
60
- except Exception as e:
61
- return str(e)
62
-
63
-
64
- async def _arun(self, query) -> str:
65
- """Use the tool asynchronously."""
66
- raise NotImplementedError("this tool does not support async")
67
-
68
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sat Oct 26 15:35:19 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+ from langchain_community.embeddings import OllamaEmbeddings
9
+ from langchain.tools import BaseTool
10
+ from langchain_openai import ChatOpenAI
11
+ from langchain_core.messages import HumanMessage, SystemMessage
12
+ from langchain.base_language import BaseLanguageModel
13
+ import base64
14
+ from io import BytesIO
15
+ from PIL import Image
16
+
17
+
18
+
19
+ def convert_to_base64(pil_image):
20
+ buffered = BytesIO()
21
+ pil_image.save(buffered, format="PNG")
22
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
23
+ return img_str
24
+
25
+
26
+ class Imageanalysis(BaseTool):
27
+ name: str = "Imageanalysis"
28
+ description: str = (
29
+ "Useful to answer questions according to the image, figure, diagram or graph. "
30
+ "Useful to analysis the information in the image, figure, diagram or graph. "
31
+ "Input query about image/figure/graph/diagram, return the response"
32
+ )
33
+ return_direct: bool = True
34
+ llm: BaseLanguageModel = None
35
+ path : str = None
36
+
37
+ def __init__(self, path):
38
+ super().__init__( )
39
+ self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",api_key=os.getenv("OPENAI_API_KEY"),
40
+ base_url=os.getenv("OPENAI_API_BASE"))
41
+ self.path = path
42
+ # api keys
43
+
44
+ def _run(self, query ) -> str:
45
+ try:
46
+ pil_image = Image.open(self.path)
47
+ rgb_im = pil_image.convert('RGB')
48
+ image_b64 = convert_to_base64(pil_image)
49
+ message = HumanMessage(
50
+ content=[
51
+ {"type": "text", "text": query},
52
+ {
53
+ "type": "image_url",
54
+ "image_url": {"url":f"data:image/jpeg;base64,{image_b64}"},
55
+ },
56
+ ],)
57
+ response = self.llm.invoke([message])
58
+ return response.content
59
+
60
+ except Exception as e:
61
+ return str(e)
62
+
63
+
64
+ async def _arun(self, query) -> str:
65
+ """Use the tool asynchronously."""
66
+ raise NotImplementedError("this tool does not support async")
67
+
68