santimber commited on
Commit
b39a153
·
1 Parent(s): 81917a3

new tools

Browse files
Files changed (1) hide show
  1. tools.py +172 -0
tools.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import requests
3
+ from PIL import Image as PILImage
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
6
+ from huggingface_hub import list_models
7
+ import random
8
+ import pprint
9
+ from langchain_community.tools import DuckDuckGoSearchRun
10
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
11
+ from langgraph.prebuilt import tools_condition
12
+ from langgraph.graph import START, StateGraph
13
+ from IPython.display import Image, display
14
+
15
+ from langgraph.prebuilt import ToolNode
16
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
17
+ from langgraph.graph.message import add_messages
18
+ from typing import TypedDict, Annotated
19
+ from langchain.tools import Tool
20
+ from langchain_community.retrievers import BM25Retriever
21
+ from langchain.docstore.document import Document
22
+ import datasets
23
+ from langchain_openai import ChatOpenAI
24
+ from dotenv import load_dotenv
25
+ import os
26
+ import torch
27
+ import base64
28
+
29
+ # DEFINE HUB STAT TOOLS
30
+
31
+
32
+ def get_hub_stats(author: str) -> str:
33
+ """Fetches the most downloaded model from a specific author on the Hugging Face Hub."""
34
+ try:
35
+ # List models from the specified author, sorted by downloads
36
+ models = list(list_models(
37
+ author=author, sort="downloads", direction=-1, limit=1))
38
+
39
+ if models:
40
+ model = models[0]
41
+ return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
42
+ else:
43
+ return f"No models found for author {author}."
44
+ except Exception as e:
45
+ return f"Error fetching models for {author}: {str(e)}"
46
+
47
+
48
+ # Initialize the tool
49
+ hub_stats_tool = Tool(
50
+ name="get_hub_stats",
51
+ func=get_hub_stats,
52
+ description="Search HuggingFace Hub for model statistics, downloads, and author information. Use this when asking about specific models, authors, or HuggingFace Hub data."
53
+ )
54
+
55
+ print(hub_stats_tool.invoke("google"))
56
+
57
+ # DEFINE WEB SEARCH TOOLS
58
+ web_search_tool = Tool(
59
+ name="search_tool",
60
+ func=DuckDuckGoSearchRun(),
61
+ description="Search the general web for current information, news, and general knowledge. Use this for questions about companies, people, events, etc."
62
+ )
63
+
64
+ results = web_search_tool.invoke("what is the matrix?")
65
+ pp = pprint.PrettyPrinter()
66
+ print(pp.pprint(results))
67
+
68
+
69
+ # REVERSE TOOLS
70
+ def ReverseTextTool(text: str) -> str:
71
+ return text[::-1]
72
+
73
+
74
+ reverse_text_tool = Tool(
75
+ name="reverse_text_tool",
76
+ func=ReverseTextTool,
77
+ description="Reverses the order of characters in a given text string. Use this when you need to reverse text."
78
+ )
79
+
80
+ results = reverse_text_tool.invoke("what is the matrix?")
81
+ pp = pprint.PrettyPrinter()
82
+ print(pp.pprint(results))
83
+
84
+ # DOWNLOAD A FILE
85
+
86
+
87
+ def download_file(url: str) -> str:
88
+ try:
89
+ response = requests.get(url, timeout=30)
90
+ response.raise_for_status()
91
+
92
+ # Define save_path - extract filename from URL
93
+ filename = url.split(
94
+ '/')[-1] if url.split('/')[-1] else 'downloaded_file'
95
+ save_path = f"./{filename}"
96
+
97
+ with open(save_path, "wb") as f:
98
+ f.write(response.content)
99
+ return save_path # Return the file path instead of success message
100
+ except Exception as e:
101
+ return f"Failed to download: {e}"
102
+
103
+
104
+ download_file_tool = Tool(
105
+ name="download_file_tool",
106
+ func=download_file,
107
+ description="Downloads a file from a given URL."
108
+ )
109
+
110
+ results = download_file_tool.invoke("https://www.google.com")
111
+ print(results)
112
+
113
+ # DEFINE IMAGE RECOGNITION TOOLS
114
+ vision_llm = ChatOpenAI(model="gpt-4o")
115
+
116
+
117
+ def image_recognition(img_path: str) -> str:
118
+
119
+ all_text = ""
120
+ try:
121
+ # Read image and encode as base64
122
+ with open(img_path, "rb") as image_file:
123
+ image_bytes = image_file.read()
124
+
125
+ image_base64 = base64.b64encode(image_bytes).decode("utf-8")
126
+
127
+ # Prepare the prompt including the base64 image data
128
+ message = [
129
+ HumanMessage(
130
+ content=[
131
+ {
132
+ "type": "text",
133
+ "text": (
134
+ "Describe the image or extract all the text from this image. "
135
+ "Return only the description orextracted text, no explanations."
136
+ ),
137
+ },
138
+ {
139
+ "type": "image_url",
140
+ "image_url": {
141
+ "url": f"data:image/png;base64,{image_base64}"
142
+ },
143
+ },
144
+ ]
145
+ )
146
+ ]
147
+
148
+ # Call the vision-capable model
149
+ response = vision_llm.invoke(message)
150
+
151
+ # Append extracted text
152
+ all_text += response.content + "\n\n"
153
+
154
+ return all_text.strip()
155
+ except Exception as e:
156
+ # A butler should handle errors gracefully
157
+ error_msg = f"Error extracting text: {str(e)}"
158
+ print(error_msg)
159
+ return ""
160
+
161
+
162
+ image_recognition_tool = Tool(
163
+ name="image_recognition_tool",
164
+ func=image_recognition,
165
+ description="Analyzes and describes the content of images using AI vision. Use this when you need to understand what's in an image."
166
+ )
167
+ test_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
168
+
169
+ results = image_recognition_tool.invoke(download_file_tool.invoke(test_url))
170
+ print(results)
171
+
172
+ #