santimber commited on
Commit
8ae9f7c
·
1 Parent(s): aabbe79
Files changed (2) hide show
  1. agent_tools.py +0 -215
  2. app.py +1 -1
agent_tools.py DELETED
@@ -1,215 +0,0 @@
1
- # %%
2
- from io import BytesIO
3
- import requests
4
- from PIL import Image as PILImage
5
- from transformers import BlipProcessor, BlipForConditionalGeneration
6
- from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
7
- from huggingface_hub import list_models
8
- import random
9
- import pprint
10
- from langchain_community.tools import DuckDuckGoSearchRun
11
- from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
12
- from langgraph.prebuilt import tools_condition
13
- from langgraph.graph import START, StateGraph
14
- from IPython.display import Image, display
15
-
16
- from langgraph.prebuilt import ToolNode
17
- from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
18
- from langgraph.graph.message import add_messages
19
- from typing import TypedDict, Annotated
20
- from langchain.tools import Tool
21
- from langchain_community.retrievers import BM25Retriever
22
- from langchain.docstore.document import Document
23
- import datasets
24
- from langchain_openai import ChatOpenAI
25
- from dotenv import load_dotenv
26
- import os
27
- import torch
28
- import base64
29
-
30
- # Load environment variables
31
- load_dotenv()
32
-
33
- # DEFINE HUB STAT TOOLS
34
-
35
-
36
- def get_hub_stats(author: str) -> str:
37
- """Fetches the most downloaded model from a specific author on the Hugging Face Hub."""
38
- try:
39
- # List models from the specified author, sorted by downloads
40
- models = list(list_models(
41
- author=author, sort="downloads", direction=-1, limit=1))
42
-
43
- if models:
44
- model = models[0]
45
- return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
46
- else:
47
- return f"No models found for author {author}."
48
- except Exception as e:
49
- return f"Error fetching models for {author}: {str(e)}"
50
-
51
-
52
- # Initialize the tool
53
- hub_stats_tool = Tool(
54
- name="get_hub_stats",
55
- func=get_hub_stats,
56
- description="Search HuggingFace Hub for model statistics, downloads, and author information. Use this when asking about specific models, authors, or HuggingFace Hub data."
57
- )
58
-
59
- # DEFINE WEB SEARCH TOOLS
60
- web_search_tool = Tool(
61
- name="search_tool",
62
- func=DuckDuckGoSearchRun(),
63
- description="Search the general web for current information, news, and general knowledge. Use this for questions about companies, people, events, etc."
64
- )
65
-
66
- # REVERSE TOOLS
67
-
68
-
69
- def ReverseTextTool(text: str) -> str:
70
- """Reverses the order of characters in a given text string."""
71
- try:
72
- return text[::-1]
73
- except Exception as e:
74
- return f"Error reversing text: {str(e)}"
75
-
76
-
77
- reverse_text_tool = Tool(
78
- name="reverse_text_tool",
79
- func=ReverseTextTool,
80
- description="Reverses the order of characters in a given text string. Use this when you need to reverse text."
81
- )
82
-
83
- # DOWNLOAD A FILE
84
-
85
-
86
- def download_file(url: str) -> str:
87
- """Downloads a file from a given URL and returns the local file path."""
88
- try:
89
- response = requests.get(url, timeout=30)
90
- response.raise_for_status()
91
-
92
- # Define save_path - extract filename from URL
93
- filename = url.split(
94
- '/')[-1] if url.split('/')[-1] else 'downloaded_file'
95
- save_path = f"./{filename}"
96
-
97
- with open(save_path, "wb") as f:
98
- f.write(response.content)
99
- return save_path
100
- except Exception as e:
101
- return f"Failed to download: {e}"
102
-
103
-
104
- download_file_tool = Tool(
105
- name="download_file_tool",
106
- func=download_file,
107
- description="Downloads a file from a given URL and returns the local file path."
108
- )
109
-
110
- # DEFINE IMAGE RECOGNITION TOOLS
111
-
112
-
113
- def create_vision_llm():
114
- """Creates a vision-capable LLM with proper error handling."""
115
- try:
116
- # Check if OpenAI API key is available
117
- if not os.getenv("OPENAI_API_KEY"):
118
- return None, "OpenAI API key not found. Please set OPENAI_API_KEY in your environment variables."
119
-
120
- vision_llm = ChatOpenAI(model="gpt-4o")
121
- return vision_llm, None
122
- except Exception as e:
123
- return None, f"Error creating vision LLM: {str(e)}"
124
-
125
-
126
- def image_recognition(img_path: str) -> str:
127
- """Analyzes and describes the content of images using AI vision."""
128
- try:
129
- # Check if file exists
130
- if not os.path.exists(img_path):
131
- return f"Error: Image file not found at {img_path}"
132
-
133
- # Create vision LLM
134
- vision_llm, error = create_vision_llm()
135
- if error:
136
- return error
137
-
138
- # Read image and encode as base64
139
- with open(img_path, "rb") as image_file:
140
- image_bytes = image_file.read()
141
-
142
- image_base64 = base64.b64encode(image_bytes).decode("utf-8")
143
-
144
- # Prepare the prompt including the base64 image data
145
- message = [
146
- HumanMessage(
147
- content=[
148
- {
149
- "type": "text",
150
- "text": (
151
- "Describe the image or extract all the text from this image. "
152
- "Return only the description or extracted text, no explanations."
153
- ),
154
- },
155
- {
156
- "type": "image_url",
157
- "image_url": {
158
- "url": f"data:image/png;base64,{image_base64}"
159
- },
160
- },
161
- ]
162
- )
163
- ]
164
-
165
- # Call the vision-capable model
166
- response = vision_llm.invoke(message)
167
- return response.content.strip()
168
-
169
- except Exception as e:
170
- return f"Error analyzing image: {str(e)}"
171
-
172
-
173
- image_recognition_tool = Tool(
174
- name="image_recognition_tool",
175
- func=image_recognition,
176
- description="Analyzes and describes the content of images using AI vision. Use this when you need to understand what's in an image."
177
- )
178
-
179
- # Test functions (commented out to avoid side effects)
180
-
181
-
182
- def test_tools():
183
- """Test all tools to ensure they work properly."""
184
- print("Testing Hub Stats Tool:")
185
- print(hub_stats_tool.invoke("google"))
186
- print("\n" + "="*50 + "\n")
187
-
188
- print("Testing Web Search Tool:")
189
- results = web_search_tool.invoke("what is the matrix?")
190
- pp = pprint.PrettyPrinter()
191
- print(pp.pprint(results))
192
- print("\n" + "="*50 + "\n")
193
-
194
- print("Testing Reverse Text Tool:")
195
- results = reverse_text_tool.invoke("what is the matrix?")
196
- print(results)
197
- print("\n" + "="*50 + "\n")
198
-
199
- print("Testing Download File Tool:")
200
- test_url = "https://www.google.com"
201
- results = download_file_tool.invoke(test_url)
202
- print(results)
203
- print("\n" + "="*50 + "\n")
204
-
205
- print("Testing Image Recognition Tool:")
206
- test_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
207
- downloaded_file = download_file_tool.invoke(test_url)
208
- if not downloaded_file.startswith("Failed"):
209
- results = image_recognition_tool.invoke(downloaded_file)
210
- print(results)
211
- else:
212
- print("Skipping image recognition test due to download failure")
213
-
214
- # Uncomment the line below to run tests
215
- # test_tools()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -9,7 +9,7 @@ from langgraph.graph import START, StateGraph
9
  from langgraph.prebuilt import ToolNode, tools_condition
10
  from langgraph.graph.message import add_messages
11
  from typing import TypedDict, Annotated
12
- from agent_tools import image_recognition_tool, download_file_tool, reverse_text_tool, hub_stats_tool, web_search_tool
13
 
14
  # (Keep Constants as is)
15
  # --- Constants ---
 
9
  from langgraph.prebuilt import ToolNode, tools_condition
10
  from langgraph.graph.message import add_messages
11
  from typing import TypedDict, Annotated
12
+ from tools import image_recognition_tool, download_file_tool, reverse_text_tool, hub_stats_tool, web_search_tool
13
 
14
  # (Keep Constants as is)
15
  # --- Constants ---