lailaelkoussy committed
Commit 211037a · 1 Parent(s): 44bbb55

first commit

Files changed (3)
  1. README.md +3 -1
  2. requirements.txt +255 -0
  3. smolagent_chat.py +551 -0
README.md CHANGED
@@ -3,8 +3,10 @@ title: Transformers Library QA Agent
  emoji: 🐨
  colorFrom: pink
  colorTo: red
- sdk: docker
+ sdk: gradio
  pinned: false
+ app_file: smolagent_chat.py
+ python_version: 3.13
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
requirements.txt ADDED
@@ -0,0 +1,255 @@
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.13.2
+ aioitertools==0.12.0
+ aiosignal==1.4.0
+ aiosqlite==0.21.0
+ alembic==1.17.1
+ annotated-doc==0.0.3
+ annotated-types==0.7.0
+ anyio==4.11.0
+ arize-phoenix==12.9.0
+ arize-phoenix-client==1.21.0
+ arize-phoenix-evals==2.5.0
+ arize-phoenix-otel==0.13.1
+ attrs==25.4.0
+ audioop-lts==0.2.2
+ Authlib==1.6.5
+ backoff==2.2.1
+ bcrypt==5.0.0
+ beartype==0.22.5
+ beautifulsoup4==4.14.2
+ brotli==1.2.0
+ bs4==0.0.2
+ build==1.3.0
+ cachetools==6.2.1
+ certifi==2025.10.5
+ cffi==2.0.0
+ charset-normalizer==3.4.4
+ chromadb==1.3.0
+ clang==20.1.5
+ click==8.3.0
+ coloredlogs==15.0.1
+ contourpy==1.3.3
+ cryptography==46.0.3
+ cycler==0.12.1
+ cyclopts==4.2.1
+ datasets==4.4.1
+ ddgs==9.7.1
+ deprecation==2.1.0
+ dill==0.4.0
+ diskcache==5.6.3
+ distro==1.9.0
+ dnspython==2.8.0
+ docstring_parser==0.17.0
+ docutils==0.22.2
+ durationpy==0.10
+ email-validator==2.3.0
+ esprima==4.0.1
+ exceptiongroup==1.3.0
+ fastapi==0.121.0
+ ffmpy==1.0.0
+ filelock==3.20.0
+ flatbuffers==25.9.23
+ fonttools==4.60.1
+ frozenlist==1.8.0
+ fsspec==2025.9.0
+ google-auth==2.42.0
+ googleapis-common-protos==1.71.0
+ gradio==5.49.1
+ gradio_client==1.13.3
+ graphql-core==3.2.7
+ greenlet==3.2.4
+ groovy==0.1.2
+ grpc-interceptor==0.15.4
+ grpcio==1.76.0
+ h11==0.16.0
+ h2==4.3.0
+ hf-xet==1.2.0
+ hpack==4.1.0
+ httpcore==1.0.9
+ httptools==0.7.1
+ httpx==0.28.1
+ httpx-sse==0.4.3
+ huggingface-hub==0.36.0
+ humanfriendly==10.0
+ hyperframe==6.1.0
+ idna==3.11
+ importlib_metadata==8.7.0
+ importlib_resources==6.5.2
+ iniconfig==2.3.0
+ jaraco.classes==3.4.0
+ jaraco.context==6.0.1
+ jaraco.functools==4.3.0
+ javalang==0.13.0
+ jeepney==0.9.0
+ Jinja2==3.1.6
+ jiter==0.11.1
+ jmespath==1.0.1
+ joblib==1.5.2
+ jsonpatch==1.33
+ jsonpath-ng==1.7.0
+ jsonpointer==3.0.0
+ jsonref==1.1.0
+ jsonschema==4.25.1
+ jsonschema-path==0.3.4
+ jsonschema-specifications==2025.9.1
+ keyring==25.6.0
+ kiwisolver==1.4.9
+ kubernetes==34.1.0
+ lance-namespace==0.0.21
+ lance-namespace-urllib3-client==0.0.21
+ lancedb==0.25.3
+ langchain-core==1.0.1
+ langchain-text-splitters==1.0.0
+ langfuse==3.9.0
+ langsmith==0.4.38
+ lxml==6.0.2
+ Mako==1.3.10
+ markdown-it-py==4.0.0
+ MarkupSafe==3.0.3
+ matplotlib==3.10.7
+ mcp==1.10.1
+ mcpadapt==0.1.20
+ mdurl==0.1.2
+ mmh3==5.2.0
+ more-itertools==10.8.0
+ mpmath==1.3.0
+ multidict==6.7.0
+ multiprocess==0.70.18
+ nest-asyncio==1.6.0
+ networkx==3.5
+ numpy==2.3.4
+ nvidia-cublas-cu12==12.8.4.1
+ nvidia-cuda-cupti-cu12==12.8.90
+ nvidia-cuda-nvrtc-cu12==12.8.93
+ nvidia-cuda-runtime-cu12==12.8.90
+ nvidia-cudnn-cu12==9.10.2.21
+ nvidia-cufft-cu12==11.3.3.83
+ nvidia-cufile-cu12==1.13.1.3
+ nvidia-curand-cu12==10.3.9.90
+ nvidia-cusolver-cu12==11.7.3.90
+ nvidia-cusparse-cu12==12.5.8.93
+ nvidia-cusparselt-cu12==0.7.1
+ nvidia-nccl-cu12==2.27.5
+ nvidia-nvjitlink-cu12==12.8.93
+ nvidia-nvshmem-cu12==3.3.20
+ nvidia-nvtx-cu12==12.8.90
+ oauthlib==3.3.1
+ onnxruntime==1.23.2
+ openai==2.6.1
+ openapi-pydantic==0.5.1
+ openinference-instrumentation==0.1.42
+ openinference-instrumentation-smolagents==0.1.19
+ openinference-semantic-conventions==0.1.25
+ opentelemetry-api==1.38.0
+ opentelemetry-exporter-otlp==1.38.0
+ opentelemetry-exporter-otlp-proto-common==1.38.0
+ opentelemetry-exporter-otlp-proto-grpc==1.38.0
+ opentelemetry-exporter-otlp-proto-http==1.38.0
+ opentelemetry-instrumentation==0.59b0
+ opentelemetry-proto==1.38.0
+ opentelemetry-sdk==1.38.0
+ opentelemetry-semantic-conventions==0.59b0
+ orjson==3.11.4
+ overrides==7.7.0
+ packaging==25.0
+ pandas==2.3.3
+ pathable==0.4.4
+ pathvalidate==3.3.1
+ pillow==11.3.0
+ pip==25.3
+ platformdirs==4.5.0
+ pluggy==1.6.0
+ ply==3.11
+ posthog==5.4.0
+ primp==0.15.0
+ prometheus_client==0.23.1
+ propcache==0.4.1
+ protobuf==6.33.0
+ psutil==7.1.3
+ py-key-value-aio==0.2.8
+ py-key-value-shared==0.2.8
+ pyarrow==22.0.0
+ pyasn1==0.6.1
+ pyasn1_modules==0.4.2
+ pybase64==1.4.2
+ pycparser==2.23
+ pydantic==2.11.10
+ pydantic_core==2.33.2
+ pydantic-settings==2.11.0
+ pydub==0.25.1
+ Pygments==2.19.2
+ PyJWT==2.10.1
+ pylance==0.39.0
+ pyparsing==3.2.5
+ pyperclip==1.11.0
+ PyPika==0.48.9
+ pyproject_hooks==1.2.0
+ pystache==0.6.8
+ pytest==8.4.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.2.1
+ python-multipart==0.0.20
+ pytz==2025.2
+ PyYAML==6.0.3
+ referencing==0.36.2
+ regex==2025.10.23
+ requests==2.32.5
+ requests-oauthlib==2.0.0
+ requests-toolbelt==1.0.0
+ rich==14.2.0
+ rich-rst==1.3.2
+ rpds-py==0.28.0
+ rsa==4.9.1
+ ruff==0.14.5
+ safehttpx==0.1.7
+ safetensors==0.6.2
+ scikit-learn==1.7.2
+ scipy==1.16.3
+ SecretStorage==3.4.0
+ semantic-version==2.10.0
+ sentence-transformers==5.1.2
+ setuptools==80.9.0
+ shellingham==1.5.4
+ six==1.17.0
+ smolagents==1.22.0
+ sniffio==1.3.1
+ socksio==1.0.0
+ soupsieve==2.8
+ SQLAlchemy==2.0.44
+ sqlean.py==3.50.4.5
+ sse-starlette==3.0.3
+ starlette==0.49.3
+ strawberry-graphql==0.270.1
+ sympy==1.14.0
+ tantivy==0.25.0
+ tenacity==9.1.2
+ threadpoolctl==3.6.0
+ tokenizers==0.22.1
+ tomlkit==0.13.3
+ torch==2.9.0
+ tqdm==4.67.1
+ transformers==4.57.1
+ tree-sitter==0.25.2
+ tree-sitter-rust==0.24.0
+ triton==3.5.0
+ typer==0.20.0
+ typer-slim==0.20.0
+ typing_extensions==4.15.0
+ typing-inspection==0.4.2
+ tzdata==2025.2
+ urllib3==2.3.0
+ uvicorn==0.38.0
+ uvloop==0.22.1
+ validators==0.35.0
+ watchfiles==1.1.1
+ weaviate-client==4.17.0
+ websocket-client==1.9.0
+ websockets==15.0.1
+ wheel==0.45.1
+ wrapt==1.17.3
+ xxhash==3.6.0
+ yarl==1.22.0
+ zipp==3.23.0
+ zstandard==0.25.0
smolagent_chat.py ADDED
@@ -0,0 +1,551 @@
+ #!/usr/bin/env python3
+ """
+ Smolagents Agent with Gradio Chat Interface that connects to the MCP server.
+ This script creates an interactive chat interface where users can query the knowledge graph
+ through a conversational AI agent.
+ """
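+
+ # Example invocation (a usage sketch based on the CLI flags defined in main()
+ # below and the MCP_SERVER_URL environment variable; the URL shown is the
+ # script's default and must point at a running MCP server):
+ #
+ #   python smolagent_chat.py --mcp-server-url http://localhost:4000/mcp --port 7861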
+
+ import os
+ import sys
+ import argparse
+ import re
+ from typing import List, Dict, Any
+ import gradio as gr
+ from gradio import ChatMessage
+ from smolagents import MCPClient, ToolCallingAgent, OpenAIServerModel, AzureOpenAIModel, InferenceClientModel, stream_to_gradio
+
+
+ class Colors:
+     """Color codes for terminal output."""
+     GREEN = '\033[92m'
+     RED = '\033[91m'
+     YELLOW = '\033[93m'
+     CYAN = '\033[96m'
+     ENDC = '\033[0m'
+     BOLD = '\033[1m'
+
+
+ def print_success(message):
+     print(f"{Colors.GREEN}✓ {message}{Colors.ENDC}")
+
+
+ def print_error(message):
+     print(f"{Colors.RED}❌ {message}{Colors.ENDC}")
+
+
+ def print_info(message):
+     print(f"{Colors.YELLOW}ℹ️ {message}{Colors.ENDC}")
+
+
+ CUSTOM_INSTRUCTIONS = """You are an expert assistant for understanding the Hugging Face Transformers library.
+
+ Your role is to help users understand the Transformers codebase by exploring the repository using the available tools. You can:
+ - Search for functions, classes, and methods in the codebase
+ - Navigate the file structure and understand code organization
+ - Find relationships between different components
+ - Trace how code flows through the library
+ - Explain implementation details and design patterns
+
+ When answering questions:
+ 1. Use the available tools to explore the repository and gather accurate information
+ 2. Provide clear, well-structured explanations based on the actual code
+ 3. Reference specific files, functions, or classes when relevant
+ 4. If you're unsure about something, search the codebase to verify before answering
+
+ Always base your answers on the actual code in the repository, not assumptions."""
+
+
+ class KnowledgeGraphChatAgent:
+     """A chat agent that connects to the Knowledge Graph MCP server."""
+
+     def __init__(self, mcp_server_url: str = None):
+         """Initialize the chat agent with MCP server connection."""
+         self.mcp_server_url = mcp_server_url or os.getenv("MCP_SERVER_URL", "http://localhost:4000/mcp")
+         self.model = None
+         self.agent = None
+         self.mcp_client = None
+         self.tools = None
+         self.conversation_history = []
+
+         # Initialize MCP tools first (required for agent)
+         self._initialize_mcp_tools()
+
+     def _initialize_mcp_tools(self):
+         """Initialize MCP client and load tools (must be done before agent creation)."""
+         try:
+             print_info(f"Connecting to MCP server at {self.mcp_server_url}...")
+             self.mcp_client = MCPClient({"url": self.mcp_server_url, "transport": "streamable-http"})
+             self.tools = self.mcp_client.__enter__()
+             print_success(f"MCP tools loaded successfully! ({len(self.tools)} tools available)")
+         except Exception as e:
+             print_error(f"Failed to connect to MCP server: {e}")
+             raise
+
+     def _initialize_model(self, model_type: str = "openai", api_key: str = None,
+                           base_url: str = None, model_name: str = None,
+                           api_version: str = None):
+         """Initialize the OpenAI, Azure OpenAI, or HF Inference model with provided configuration."""
+
+         print_info(f"Initializing model:")
+         print(f"  Model Type: {model_type}")
+         print(f"  Model: {model_name}")
+
+         try:
+             if model_type == "azure":
+                 api_key = api_key or os.environ.get('AZURE_OPENAI_API_KEY')
+                 base_url = base_url or os.environ.get('AZURE_OPENAI_ENDPOINT')
+                 api_version = api_version or os.environ.get('OPENAI_API_VERSION', '2024-02-15-preview')
+
+                 if not api_key:
+                     raise ValueError("Azure API key is required!")
+                 if not base_url:
+                     raise ValueError("Azure endpoint is required!")
+
+                 print(f"  Endpoint: {base_url}")
+                 print(f"  API Version: {api_version}")
+
+                 self.model = AzureOpenAIModel(
+                     model_id=model_name,
+                     azure_endpoint=base_url,
+                     api_key=api_key,
+                     api_version=api_version
+                 )
+             elif model_type == "hf_inference":
+                 api_key = api_key or os.environ.get('HF_TOKEN')
+                 model_name = model_name or os.environ.get('HF_MODEL_NAME', 'Qwen/Qwen2.5-Coder-32B-Instruct')
+                 provider = base_url or os.environ.get('HF_INFERENCE_PROVIDER', '')
+
+                 if not api_key:
+                     raise ValueError("HuggingFace token is required!")
+
+                 print(f"  Model: {model_name}")
+                 if provider:
+                     print(f"  Provider: {provider}")
+
+                 # Build kwargs for InferenceClientModel
+                 model_kwargs = {
+                     "model_id": model_name,
+                     "token": api_key,
+                     "bill_to": "epita"
+                 }
+                 if provider:
+                     model_kwargs["provider"] = provider
+
+                 self.model = InferenceClientModel(**model_kwargs)
+             else:  # openai
+                 api_key = api_key or os.environ.get('OPENAI_API_KEY')
+                 base_url = base_url or os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1')
+                 model_name = model_name or os.environ.get('OPENAI_MODEL_NAME', 'gpt-4o-mini')
+
+                 if not api_key:
+                     raise ValueError("OpenAI API key is required!")
+
+                 print(f"  Base URL: {base_url}")
+
+                 self.model = OpenAIServerModel(
+                     model_id=model_name,
+                     api_key=api_key,
+                     api_base=base_url
+                 )
+             print_success("Model initialized successfully!")
+         except Exception as e:
+             print_error(f"Failed to initialize model: {e}")
+             raise
+
+     def _initialize_agent(self, max_steps: int = None):
+         """Initialize the agent using the configured model and pre-loaded MCP tools."""
+         if not self.model:
+             raise ValueError("Model must be initialized before creating agent!")
+         if not self.tools:
+             raise ValueError("MCP tools must be loaded before creating agent!")
+
+         try:
+             max_steps = max_steps or int(os.getenv("MAX_STEPS", 5))
+
+             self.agent = ToolCallingAgent(
+                 tools=self.tools,
+                 model=self.model,
+                 name="KnowledgeGraphAgent",
+                 max_steps=max_steps,
+                 add_base_tools=False,
+                 instructions=CUSTOM_INSTRUCTIONS
+             )
+             print_success("Agent initialized successfully!")
+         except Exception as e:
+             print_error(f"Failed to initialize agent: {e}")
+             raise
+
+     def is_ready(self):
+         """Check if the agent is fully initialized and ready to chat."""
+         return self.agent is not None and self.model is not None
+
+     def _parse_thinking_tags(self, text: str):
+         """
+         Extract content from <think> tags and return both thinking content and clean text.
+
+         Args:
+             text: Text that may contain <think>...</think> tags
+
+         Returns:
+             tuple: (thinking_content, clean_text)
+         """
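+         # Illustrative example (hypothetical input):
+         #   "<think>plan the search</think>Here is the answer"
+         #   -> (["plan the search"], "Here is the answer")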
+         # Find all <think>...</think> blocks
+         think_pattern = r'<think>(.*?)</think>'
+         thoughts = re.findall(think_pattern, text, re.DOTALL)
+
+         # Remove <think> tags from the text
+         clean_text = re.sub(think_pattern, '', text, flags=re.DOTALL).strip()
+
+         return thoughts, clean_text
+
+     def chat(self, message: str, history: List[Dict[str, Any]]):
+         """
+         Process a chat message and stream the response using messages format.
+
+         Args:
+             message: The user's message
+             history: The conversation history as list of message dictionaries
+
+         Yields:
+             Updated history with new messages including thinking and tool usage
+         """
+         if not message.strip():
+             yield history
+             return
+
+         # Add user message
+         history.append(ChatMessage(role="user", content=message))
+         yield history
+
+         try:
+             print_info(f"Processing query: {message}")
+
+             # Stream agent output using stream_to_gradio
+             for chat_message in stream_to_gradio(self.agent, message):
+                 # Parse for <think> tags
+                 content = chat_message.content if isinstance(chat_message.content, str) else str(chat_message.content)
+                 thoughts, clean_content = self._parse_thinking_tags(content)
+
+                 # Display thinking content if present
+                 for thought in thoughts:
+                     history.append(ChatMessage(
+                         role="assistant",
+                         content=thought.strip(),
+                         metadata={"title": "🧠 Model Thinking"}
+                     ))
+                     yield history
+
+                 # Add the message with cleaned content
+                 if clean_content:
+                     if hasattr(chat_message, 'metadata') and chat_message.metadata:
+                         # Preserve original metadata from stream_to_gradio
+                         history.append(ChatMessage(
+                             role=chat_message.role,
+                             content=clean_content,
+                             metadata=chat_message.metadata
+                         ))
+                     else:
+                         # Regular message without metadata
+                         history.append(ChatMessage(
+                             role=chat_message.role,
+                             content=clean_content
+                         ))
+                     yield history
+
+             print_success("Query processed successfully!")
+
+         except Exception as e:
+             error_msg = f"Error processing query: {str(e)}"
+             print_error(error_msg)
+             # Remove pending messages if present
+             if history and len(history) > 0:
+                 last_msg = history[-1]
+                 if hasattr(last_msg, 'metadata') and last_msg.metadata and last_msg.metadata.get('status') == 'pending':
+                     history = history[:-1]
+             history.append(ChatMessage(role="assistant", content=error_msg))
+             yield history
+
+     def cleanup(self):
+         """Clean up resources."""
+         if self.mcp_client:
+             try:
+                 self.mcp_client.__exit__(None, None, None)
+             except Exception as e:
+                 print_error(f"Error during cleanup: {e}")
+
+
+ def create_gradio_interface(agent: KnowledgeGraphChatAgent):
+     """Create the Gradio chat interface with model configuration."""
+
+     with gr.Blocks(title="🤗 Transformers Q&A Agent", theme=gr.themes.Soft()) as demo:
+
+         # ==================== INITIALIZATION SECTION ====================
+         with gr.Column(visible=not agent.is_ready()) as init_section:
+             gr.Markdown("""
+             # 🤗 Transformers Library Q&A Agent
+
+             Welcome! This AI agent helps you understand the **Hugging Face Transformers** library.
+             Ask questions about the codebase, find functions, explore classes, and understand how components work together.
+
+             Configure your AI model below to get started. The MCP server tools are already connected!
+             """)
+
+             with gr.Group():
+                 gr.Markdown("### ⚙️ Model Configuration")
+
+                 with gr.Row():
+                     model_type = gr.Dropdown(
+                         choices=["openai", "azure", "hf_inference"],
+                         value="openai",
+                         label="Model Type",
+                         info="Choose between OpenAI, Azure OpenAI, or HuggingFace Inference"
+                     )
+
+                 # Model name field (shown for all types)
+                 with gr.Row() as model_name_row:
+                     model_name = gr.Textbox(
+                         label="Model Name",
+                         value=os.environ.get('OPENAI_MODEL_NAME', 'gpt-4o-mini'),
+                         info="e.g., gpt-4o-mini, gpt-4, Qwen/Qwen2.5-Coder-32B-Instruct"
+                     )
+
+                 # OpenAI specific fields
+                 with gr.Row(visible=True) as openai_fields:
+                     api_key = gr.Textbox(
+                         label="API Key",
+                         value=os.environ.get('OPENAI_API_KEY', ''),
+                         type="password",
+                         info="Your OpenAI API key"
+                     )
+                     base_url = gr.Textbox(
+                         label="Base URL",
+                         value=os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1'),
+                         info="API endpoint URL"
+                     )
+
+                 # Azure specific fields
+                 with gr.Row(visible=False) as azure_fields:
+                     azure_api_key = gr.Textbox(
+                         label="Azure API Key",
+                         value=os.environ.get('AZURE_OPENAI_API_KEY', ''),
+                         type="password",
+                         info="Your Azure OpenAI API key"
+                     )
+                     azure_endpoint = gr.Textbox(
+                         label="Azure Endpoint",
+                         value=os.environ.get('AZURE_OPENAI_ENDPOINT', ''),
+                         info="Azure OpenAI endpoint URL"
+                     )
+
+                 with gr.Row(visible=False) as azure_version_row:
+                     api_version = gr.Textbox(
+                         label="API Version",
+                         value=os.environ.get('OPENAI_API_VERSION', '2024-02-15-preview'),
+                         info="Azure OpenAI API version"
+                     )
+
+                 # HuggingFace Inference specific fields
+                 with gr.Row(visible=False) as hf_fields:
+                     hf_token = gr.Textbox(
+                         label="HuggingFace Token",
+                         value=os.environ.get('HF_TOKEN', ''),
+                         type="password",
+                         info="Your HuggingFace API token"
+                     )
+                     hf_provider = gr.Textbox(
+                         label="Inference Provider (Optional)",
+                         value=os.environ.get('HF_INFERENCE_PROVIDER', ''),
+                         info="Provider name (e.g., 'together', 'fireworks-ai', 'cerebras'). Leave empty for auto."
+                     )
+
+                 with gr.Row():
+                     max_steps = gr.Number(
+                         label="Max Steps",
+                         value=int(os.getenv("MAX_STEPS", 5)),
+                         minimum=1,
+                         maximum=20,
+                         info="Maximum reasoning steps for the agent"
+                     )
+
+             init_status = gr.Markdown("**Status:** ⚠️ Please configure and initialize the agent")
+             init_btn = gr.Button("🚀 Initialize Agent", variant="primary", size="lg")
+
+             # Toggle visibility based on model type
+             def toggle_model_fields(mtype):
+                 if mtype == "azure":
+                     return (
+                         gr.update(visible=False),  # openai_fields
+                         gr.update(visible=True),   # azure_fields
+                         gr.update(visible=True),   # azure_version_row
+                         gr.update(visible=False)   # hf_fields
+                     )
+                 elif mtype == "hf_inference":
+                     return (
+                         gr.update(visible=False),  # openai_fields
+                         gr.update(visible=False),  # azure_fields
+                         gr.update(visible=False),  # azure_version_row
+                         gr.update(visible=True)    # hf_fields
+                     )
+                 else:  # openai
+                     return (
+                         gr.update(visible=True),   # openai_fields
+                         gr.update(visible=False),  # azure_fields
+                         gr.update(visible=False),  # azure_version_row
+                         gr.update(visible=False)   # hf_fields
+                     )
+
+             model_type.change(
+                 fn=toggle_model_fields,
+                 inputs=[model_type],
+                 outputs=[openai_fields, azure_fields, azure_version_row, hf_fields]
+             )
+
+         # ==================== CHAT SECTION ====================
+         with gr.Column(visible=agent.is_ready()) as chat_section:
+             gr.Markdown("""
+             # 🤗 Transformers Library Q&A Agent
+
+             Ask me anything about the **Hugging Face Transformers** library! I can help you:
+             - 🔍 Find and explain functions, classes, and methods
+             - 🗺️ Navigate the codebase structure and understand file organization
+             - 🔗 Trace relationships and dependencies between components
+             - 📖 Explain implementation details and design patterns
+             """)
+
+             chatbot = gr.Chatbot(
+                 label="Transformers Q&A",
+                 height=500,
+                 show_copy_button=True,
+                 type="messages"
+             )
+
+             with gr.Row():
+                 msg = gr.Textbox(
+                     label="Your Question",
+                     placeholder="Ask about the Transformers library... (e.g., 'How does BertModel work?')",
+                     scale=4,
+                     lines=1
+                 )
+                 submit_btn = gr.Button("Send", variant="primary", scale=1)
+
+             with gr.Row():
+                 clear_btn = gr.Button("Clear Chat", variant="secondary")
+
+             gr.Markdown("""
+             ### 💡 Example Questions:
+             - "How does the `AutoModel` class work?"
+             - "What is the structure of a model's `forward` method?"
+             - "Find all classes that inherit from `PreTrainedModel`"
+             - "How does tokenization work in the library?"
+             - "What files are involved in the BERT implementation?"
+             """)
+
+         # Handle agent initialization
+         def initialize_agent(mtype, mname, akey, burl, azure_akey, azure_ep, aversion, hf_tok, hf_prov, msteps):
+             try:
+                 if mtype == "azure":
+                     agent._initialize_model(
+                         model_type=mtype,
+                         api_key=azure_akey,
+                         base_url=azure_ep,
+                         model_name=mname,
+                         api_version=aversion
+                     )
+                 elif mtype == "hf_inference":
+                     agent._initialize_model(
+                         model_type=mtype,
+                         api_key=hf_tok,
+                         model_name=mname,
+                         base_url=hf_prov if hf_prov else None
+                     )
+                 else:  # openai
+                     agent._initialize_model(
+                         model_type=mtype,
+                         api_key=akey,
+                         base_url=burl,
+                         model_name=mname
+                     )
+                 agent._initialize_agent(max_steps=int(msteps))
+                 return (
+                     gr.update(value="**Status:** ✅ Agent Ready!"),
+                     gr.update(visible=False),  # Hide init section
+                     gr.update(visible=True)    # Show chat section
+                 )
+             except Exception as e:
+                 error_msg = f"**Status:** ❌ Initialization failed: {str(e)}"
+                 return (
+                     gr.update(value=error_msg),
+                     gr.update(visible=True),   # Keep init section visible
+                     gr.update(visible=False)   # Keep chat section hidden
+                 )
+
+         init_btn.click(
+             fn=initialize_agent,
+             inputs=[model_type, model_name, api_key, base_url, azure_api_key, azure_endpoint, api_version, hf_token, hf_provider, max_steps],
+             outputs=[init_status, init_section, chat_section]
+         )
+
+         # Handle message submission with streaming
+         def submit_message(message, history):
+             for updated_history in agent.chat(message, history):
+                 yield "", updated_history
+
+         submit_btn.click(
+             fn=submit_message,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot]
+         )
+
+         msg.submit(
+             fn=submit_message,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot]
+         )
+
+         clear_btn.click(
+             fn=lambda: [],
+             outputs=chatbot
+         )
+
+     return demo
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Smolagents Chat Agent with Gradio Interface")
+     parser.add_argument("--mcp-server-url", type=str, help="URL of the MCP server")
+     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
+     parser.add_argument("--port", type=int, default=7861, help="Port to bind to")
+     parser.add_argument("--share", action="store_true", help="Create a public link")
+
+     args = parser.parse_args()
+
+     try:
+         # Initialize the agent
+         print_info("Initializing Knowledge Graph Chat Agent...")
+         agent = KnowledgeGraphChatAgent(mcp_server_url=args.mcp_server_url)
+         print_success("Agent ready!")
+
+         # Create and launch the Gradio interface
+         demo = create_gradio_interface(agent)
+         print_info(f"Launching Gradio interface on {args.host}:{args.port}")
+
+         demo.launch(
+             server_name=args.host,
+             server_port=args.port,
+             share=args.share
+         )
+
+     except KeyboardInterrupt:
+         print_info("\nShutting down gracefully...")
+     except Exception as e:
+         print_error(f"Fatal error: {e}")
+         import traceback
+         traceback.print_exc()
+         sys.exit(1)
+     finally:
+         if 'agent' in locals():
+             agent.cleanup()
+
+
+ if __name__ == "__main__":
+     main()