pharmaia commited on
Commit
452666e
·
verified ·
1 Parent(s): 17027bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +318 -508
app.py CHANGED
@@ -1,550 +1,360 @@
1
- import gradio as gr
2
- from llama_index.core import VectorStoreIndex
3
- from llama_index.core import (
4
- StorageContext,
5
- load_index_from_storage,
6
- )
7
- from llama_index.tools.arxiv import ArxivToolSpec
8
- from llama_index.core import Settings
9
- from llama_index.llms.azure_openai import AzureOpenAI
10
- from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
11
- from llama_index.llms.openai import OpenAI
12
- from llama_index.embeddings.openai import OpenAIEmbedding
13
- from typing import Optional, List, Dict, Any
14
- from pathlib import Path
15
- import aiohttp
16
  import json
17
  import os
18
- import asyncio
 
 
 
 
 
19
 
 
 
 
20
 
21
- from gradio_client import Client, handle_file
22
- HF_TOKEN = os.environ.get('HF_TOKEN')
23
 
 
24
 
 
 
25
 
26
- ##### LLM #####
27
- openai_api_key = os.environ.get('OPENAI_API_KEY')
28
 
 
 
 
 
 
 
 
 
 
29
 
30
- llm = OpenAI(
31
- model="gpt-4.1",
32
- api_key=openai_api_key,
33
- )
34
- embed_model = OpenAIEmbedding(
35
- model="text-embedding-ada-002",
36
- api_key=openai_api_key,
37
- )
38
 
39
- Settings.llm = llm
40
- Settings.embed_model = embed_model
41
- ##### END LLM #####
42
-
43
-
44
-
45
- ##### LOAD RETRIEVERS #####
46
- DOCUMENTS_BASE_PATH = "./"
47
- RETRIEVERS_JSON_PATH = Path("./retrievers.json")
48
-
49
- # Load metadata
50
- def load_retrievers_metadata():
51
- try:
52
- with open(RETRIEVERS_JSON_PATH, 'r', encoding='utf-8') as f:
53
- return json.load(f)
54
- except Exception as e:
55
- print(f"Error loading retrievers.json: {str(e)}")
56
- print(f"Error details: {traceback.format_exc()}") # You would need to import traceback
57
- return {}
58
-
59
- retrievers_metadata = load_retrievers_metadata()
60
- SOURCES = {source: f"{source.lower()}/" for source in retrievers_metadata.keys()}
61
-
62
- # Load indexes
63
- indices: Dict[str, VectorStoreIndex] = {}
64
-
65
- for source, rel_path in SOURCES.items():
66
- full_path = os.path.join(DOCUMENTS_BASE_PATH, rel_path)
67
- if not os.path.exists(full_path):
68
- print(f"Warning: Path not found for {source}")
69
- continue
70
-
71
- for root, dirs, files in os.walk(full_path):
72
- if "storage_nodes" in dirs:
73
- try:
74
- storage_path = os.path.join(root, "storage_nodes")
75
- storage_context = StorageContext.from_defaults(persist_dir=storage_path)
76
- index_name = os.path.basename(root)
77
- indices[index_name] = load_index_from_storage(storage_context) #, index_id="vector_index"
78
- print(f"Index loaded successfully: {index_name}")
79
- except Exception as e:
80
- print(f"Error loading index {index_name}: {str(e)}")
81
- print(f"Error details: {traceback.format_exc()}")
82
-
83
-
84
-
85
-
86
-
87
-
88
- ##### ARXIV INSTANCE #####
89
- arxiv_tool = ArxivToolSpec(max_results=5).to_tool_list()[0]
90
- arxiv_tool.return_direct = True
91
-
92
-
93
-
94
- ##### MCP TOOLS #####
95
-
96
- async def search_arxiv(
97
- query: str,
98
- max_results: int = 5
99
- ) -> Dict[str, Any]:
100
- """
101
- Searches for academic papers on ArXiv.
102
-
103
- Args:
104
- query: Search terms (e.g. "deep learning")
105
- max_results: Maximum number of results (1-10, default 5)
106
-
107
- Returns:
108
- Dict: Search results with paper metadata
109
- """
110
- try:
111
- # Configure maximum results
112
- max_results = min(max(1, max_results), 10)
113
- arxiv_tool.metadata.max_results = max_results
114
-
115
- # Execute search and get results
116
- tool_output = arxiv_tool(query=query)
117
-
118
- # Process documents
119
- papers = []
120
- for doc in tool_output.raw_output: # Correctly access documents
121
- content = doc.text_resource.text.split('\n')
122
- papers.append({
123
- 'title': content[0].split(': ')[1] if ': ' in content[0] else content[0],
124
- 'abstract': '\n'.join(content[1:]).strip(),
125
- 'pdf_url': content[0].split(': ')[0].replace('http://', 'https://'),
126
- 'arxiv_id': content[0].split(': ')[0].split('/')[-1].replace('v1', '')
127
- })
128
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  return {
130
- 'papers': papers,
131
- 'count': len(papers),
132
- 'query': query,
133
- 'status': 'success'
134
  }
135
-
136
- except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  return {
138
- 'papers': [],
139
- 'count': 0,
140
- 'query': query,
141
- 'status': 'error',
142
- 'error': str(e)
143
  }
144
 
145
- async def list_retrievers(source: str = None) -> dict:
146
- """
147
- Returns the list of available retrievers.
148
- If a source is specified and exists, filters by it; if it doesn't exist, returns all.
149
-
150
- Args:
151
- source (str, optional): Source to filter by. If it doesn't exist, it will be ignored. Defaults to None.
152
-
153
- Returns:
154
- dict: {
155
- "retrievers": List of retrievers (filtered or complete),
156
- "count": Total count,
157
- "status": "success"|"error",
158
- "source_requested": source, # Shows what was requested
159
- "source_used": "all"|source # Shows what was actually used
160
- }
161
- """
162
- try:
163
- available = []
164
- source_exists = source in retrievers_metadata if source else False
165
-
166
- for current_source, indexes in retrievers_metadata.items():
167
- # Only filter if source exists, otherwise show all
168
- if source_exists and current_source != source:
169
- continue
170
-
171
- for index_name, metadata in indexes.items():
172
- available.append({
173
- "name": index_name,
174
- "source": current_source,
175
- "title": metadata.get("title", ""),
176
- "description": metadata.get("description", "")
177
- })
178
-
179
  return {
180
- "retrievers": available,
181
- "count": len(available),
182
- "status": "success",
183
- "source_requested": source,
184
- "source_used": source if source_exists else "all"
185
  }
186
- except Exception as e:
 
 
 
 
 
 
 
187
  return {
188
- "retrievers": [],
189
- "count": 0,
190
- "status": "error",
191
- "error": str(e),
192
- "source_requested": source,
193
- "source_used": "none"
 
 
 
194
  }
195
 
196
-
197
- def retrieve_docs(
198
- query: str,
199
- retrievers: List[str],
200
- top_k: int = 3
201
- ) -> dict:
202
- """
203
- Performs semantic search on indexed documents.
204
-
205
- Parameters:
206
- query (str): Search text (required)
207
- retrievers (List[str]): Names of retrievers to query (required)
208
- top_k (int): Number of results per retriever (optional, default=3)
209
- """
210
- print(f"Starting search for query: '{query}'")
211
- print(f"Parameters - retrievers: {retrievers}, top_k: {top_k}")
212
-
213
- results = {}
214
- invalid = []
215
-
216
- for name in retrievers:
217
- if name not in indices:
218
- print(f"Retriever not found: {name}")
219
- invalid.append(name)
220
- continue
221
-
222
- try:
223
- print(f"Processing retriever: {name}")
224
- retriever = indices[name].as_retriever(similarity_top_k=top_k)
225
- nodes = retriever.retrieve(query)
226
- print(f"Retrieved {len(nodes)} documents from {name}")
227
-
228
- # 2. Search for COMPLETE metadata
229
- metadata = {}
230
- source = "unknown"
231
- for src, indexes in retrievers_metadata.items():
232
- if name in indexes:
233
- metadata = indexes[name]
234
- source = src
235
- break
236
- print(f"Metadata found for {name}: {metadata.keys()}")
237
-
238
- # 3. Build response
239
- results[name] = {
240
- "title": metadata.get("title", name),
241
- "documents": [
242
- {
243
- "content": node.get_content(),
244
- "metadata": node.metadata,
245
- "score": node.score
246
- }
247
- for node in nodes
248
- ],
249
- "description": metadata.get("description", ""),
250
- "source": source,
251
- "last_updated": metadata.get("last_updated", "")
252
- }
253
- print(f"Retriever {name} processed successfully")
254
-
255
- except Exception as e:
256
- print(f"Error processing retriever {name}: {str(e)}", exc_info=True)
257
- results[name] = {
258
- "error": str(e),
259
- "retriever": name
260
- }
261
-
262
- # Build final response
263
- response = {
264
- "query": query,
265
- "results": results,
266
- "top_k": top_k,
267
- }
268
-
269
- if invalid:
270
- print(f"Invalid retrievers: {invalid}. Valid options: {list(indices.keys())}")
271
- response["warnings"] = {
272
- "invalid_retrievers": invalid,
273
- "valid_options": list(indices.keys())
274
- }
275
-
276
- print(f"Search completed. Total results: {len(results)}")
277
- return response
278
-
279
-
280
- async def search_tavily(
281
- query: str,
282
- days: int = 7,
283
- max_results: int = 1,
284
- include_answer: bool = False
285
- ) -> dict:
286
- """Perform a web search using the Tavily API.
287
-
288
- Args:
289
- query: Search query string (required)
290
- days: Restrict search to last N days (default: 7)
291
- max_results: Maximum results to return (default: 1)
292
- include_answer: Include a direct answer only when requested by the user (default: False)
293
-
294
- Returns:
295
- dict: Search results from Tavily
296
- """
297
- # Get API key from environment variables
298
- tavily_api_key = os.environ.get('TAVILY_API_KEY')
299
- if not tavily_api_key:
300
- raise ValueError("TAVILY_API_KEY environment variable not set")
301
-
302
- headers = {
303
- "Authorization": f"Bearer {tavily_api_key}",
304
- "Content-Type": "application/json"
305
- }
306
-
307
- payload = {
308
- "query": query,
309
- "search_depth": "basic",
310
- "max_results": max_results,
311
- "days": days if days else None,
312
- "include_answer": include_answer
313
- }
314
-
315
- try:
316
- async with aiohttp.ClientSession() as session:
317
- async with session.post(
318
- "https://api.tavily.com/search",
319
- headers=headers,
320
- json=payload
321
- ) as response:
322
- response.raise_for_status()
323
- result = await response.json()
324
- return result
325
-
326
- except Exception as e:
327
  return {
328
- "error": str(e),
329
- "status": "failed",
330
- "query": query
331
  }
332
 
333
- ##### EVALS #####
334
- async def evaluate_answer_relevancy(
335
- query: str,
336
- response: str,
337
- ) -> float:
338
- """Evaluate how relevant the answer is to the query using AnswerRelevancyEvaluator.
339
-
340
- Args:
341
- query: Original user query (required)
342
- response: Generated response to evaluate (required)
343
-
344
- Returns:
345
- float: Relevancy score between 0 and 1 (higher is better)
346
- """
347
- try:
348
- from llama_index.core.evaluation import AnswerRelevancyEvaluator
349
-
350
- # Initialize the evaluator
351
- evaluator = AnswerRelevancyEvaluator(llm=llm)
352
-
353
- # Perform the evaluation
354
- eval_result = evaluator.evaluate(query=query, response=response)
355
-
356
- # Return the score as a float
357
- return float(eval_result.score)
358
-
359
- except Exception as e:
360
- # In case of error, return 0.0 (minimum score) and log the error
361
- print(f"Error in relevancy evaluation: {str(e)}")
362
- return 0.0
363
-
364
- async def evaluate_context_relevancy(
365
- context: str,
366
- query: str,
367
- response: str
368
- ) -> float:
369
- """Evaluates the relevance of the response considering both the query and the context.
370
-
371
- Args:
372
- context: Contextual information / knowledge base (required)
373
- query: Original user query (required)
374
- response: Generated response to evaluate (required)
375
-
376
- Returns:
377
- float: Relevance score between 0 and 1 (higher is better)
378
- """
379
- try:
380
- from llama_index.core.evaluation import ContextRelevancyEvaluator
381
-
382
- # Initialize the relevancy evaluator with context
383
- evaluator = ContextRelevancyEvaluator(llm=llm)
384
-
385
- # Perform the evaluation (adapted to handle context)
386
- eval_result = evaluator.evaluate(
387
- query=query,
388
- response=response,
389
- contexts=[context]
390
- )
391
-
392
- return float(eval_result.score)
393
-
394
- except Exception as e:
395
- print(f"Error during context relevancy evaluation: {str(e)}")
396
- return 0.0
397
-
398
- async def evaluate_faithfulness(
399
- query: str,
400
- response: str,
401
- context: str
402
- ) -> float:
403
- """Evaluate how faithful (factually consistent) the response is to the provided context.
404
-
405
- Args:
406
- query: Original user query (required)
407
- response: Generated response to evaluate (required)
408
- context: Source context/knowledge base used for the response (required)
409
-
410
- Returns:
411
- float: Faithfulness score between 0 and 1 (higher is better)
412
- """
413
- try:
414
- from llama_index.core.evaluation import FaithfulnessEvaluator
415
-
416
- # Initialize evaluator
417
- evaluator = FaithfulnessEvaluator(llm=llm)
418
-
419
- # Perform evaluation
420
- eval_result = evaluator.evaluate(
421
- query=query,
422
- response=response,
423
- contexts=[context]
424
- )
425
-
426
- # Return score as float
427
- return float(eval_result.score)
428
-
429
- except Exception as e:
430
- # On error, return 0.0 (minimum score) and log the error
431
- print(f"Error in faithfulness evaluation: {str(e)}")
432
- return 0.0
433
 
434
 
 
435
 
436
 
 
 
437
 
438
 
 
 
439
 
440
 
 
 
441
 
442
- # Gradio interface
443
- with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as arxiv_tab:
444
- arxiv_interface = gr.Interface(
445
- fn=search_arxiv,
446
- inputs=[
447
- gr.Textbox(label="Search terms", placeholder="E.g.: deep learning"),
448
- gr.Slider(1, 10, value=5, step=1, label="Maximum number of results")
449
- ],
450
- outputs=gr.JSON(label="Search results"),
451
- title="ArXiv Search",
452
- description="Search for academic papers on ArXiv using keywords.",
453
- api_name="_search_arxiv"
454
- )
455
 
456
- with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as list_retrievers_tab:
457
- retrievers_interface = gr.Interface(
458
- fn=list_retrievers,
459
- inputs=gr.Textbox(label="Source (optional)", placeholder="Leave empty to list all"),
460
- outputs=gr.JSON(label="List of retrievers"),
461
- title="List of Retrievers",
462
- description="Shows available retrievers, optionally filtered by source.",
463
- api_name="_list_retrievers"
464
- )
465
 
466
- with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as tavily_tab:
467
- tavily_interface = gr.Interface(
468
- fn=search_tavily,
469
- inputs=[
470
- gr.Textbox(label="Search query", placeholder="E.g.: latest news about AI"),
471
- gr.Slider(1, 30, value=7, step=1, label="Last N days (0 for no limit)"),
472
- gr.Slider(1, 10, value=1, step=1, label="Maximum results"),
473
- gr.Checkbox(label="Include direct answer", value=False)
474
- ],
475
- outputs=gr.JSON(label="Tavily results"),
476
- title="Web Search (Tavily)",
477
- description="Perform web searches using the Tavily API.",
478
- api_name="_search_tavily"
479
- )
480
-
481
- with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as retrieve_tab:
482
- # Interface for retrieve_docs
483
- retrieve_interface = gr.Interface(
484
- fn=retrieve_docs,
485
- inputs=[
486
- gr.Textbox(label="Query", placeholder="Enter your question or search terms..."),
487
- gr.Dropdown(
488
- choices=list(indices.keys()),
489
- label="Retrievers",
490
- multiselect=True,
491
- info="Select one or more retrievers"
492
- ),
493
- gr.Slider(1, 10, value=3, step=1, label="Number of results per retriever (top_k)")
494
- ],
495
- outputs=gr.JSON(label="Semantic search results"),
496
- title="Semantic Document Search",
497
- description="""Perform semantic search on indexed documents using retrievers.
498
- Select available retrievers and adjust the number of results.""",
499
- api_name="_retrieve"
500
- )
501
 
502
- with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as asw_relevance_tab:
503
- relevancy_interface = gr.Interface(
504
- fn=evaluate_answer_relevancy,
505
- inputs=[
506
- gr.Textbox(label="Original Query", placeholder="E.g.: How does photosynthesis work?"),
507
- gr.Textbox(label="Answer to Evaluate", placeholder="Paste the generated answer here", lines=5),
508
- ],
509
- outputs=gr.Number(label="Relevancy Score (0-1)", precision=3),
510
- title="Relevancy Evaluator (Query-Answer)",
511
- description="Evaluates how relevant an answer is to the original query (1 = perfectly relevant).",
512
- api_name="_evaluate_relevancy"
513
- )
514
 
515
- with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as ctx_relevance_tab:
516
- context_relevancy_interface = gr.Interface(
517
- fn=evaluate_context_relevancy,
518
- inputs=[
519
- gr.Textbox(label="Context", placeholder="Relevant text / knowledge base", lines=3),
520
- gr.Textbox(label="Original Query", placeholder="What question is being answered?"),
521
- gr.Textbox(label="Generated Answer", placeholder="The answer to evaluate", lines=5),
522
- ],
523
- outputs=gr.Number(label="Relevancy Score (0-1)", precision=3),
524
- title="Relevancy Evaluator (Context-Query-Answer)",
525
- description="Evaluates how relevant the answer is considering both the query and the reference context.",
526
- api_name="_evaluate_context_relevancy"
527
- )
528
 
529
- with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as faithfulness_tab:
530
- faithfulness_interface = gr.Interface(
531
- fn=evaluate_faithfulness,
532
- inputs=[
533
- gr.Textbox(label="Original Query", placeholder="E.g.: What are the causes of climate change?"),
534
- gr.Textbox(label="Answer to Evaluate", placeholder="Paste the generated answer here", lines=5),
535
- gr.Textbox(label="Context", placeholder="Reference text / knowledge base", lines=3),
536
- ],
537
- outputs=gr.Number(label="Faithfulness Score (0-1)", precision=3),
538
- title="Faithfulness Evaluator",
539
- description="Evaluates how faithful/factually consistent the answer is with respect to the provided context (1 = perfectly faithful).",
540
- api_name="_evaluate_faithfulness"
541
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
- # Create the interface with separate tabs
544
  demo = gr.TabbedInterface(
545
- [arxiv_tab, tavily_tab, list_retrievers_tab, retrieve_tab, asw_relevance_tab, ctx_relevance_tab, faithfulness_tab],
546
- ["ArXiv", "Tavily", "List Retrievers", "Retrieve", "Answer Relevance", "Context Relevance", "Faithfulness"],
547
  theme=gr.themes.Base(),
548
  )
549
 
550
- demo.launch(mcp_server=True)
 
 
1
+ import base64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import json
3
  import os
4
+ import secrets
5
+ import time
6
+ import urllib.parse
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any
10
 
11
+ import gradio as gr
12
+ import requests
13
+ from dotenv import load_dotenv
14
 
 
 
15
 
16
+ load_dotenv()
17
 
18
+ SPOTIFY_ACCOUNTS_BASE = "https://accounts.spotify.com"
19
+ SPOTIFY_API_BASE = "https://api.spotify.com/v1"
20
 
 
 
21
 
22
+ @dataclass
23
+ class AuthConfig:
24
+ client_id: str
25
+ client_secret: str
26
+ redirect_uri: str
27
+ scopes: str
28
+ token_file: Path
29
+ state_file: Path
30
+ env_refresh_token: str | None
31
 
 
 
 
 
 
 
 
 
32
 
33
+ class SpotifyClient:
34
+ def __init__(self, config: AuthConfig) -> None:
35
+ self.config = config
36
+
37
+ def _read_json(self, path: Path) -> dict[str, Any] | None:
38
+ if not path.exists():
39
+ return None
40
+ try:
41
+ return json.loads(path.read_text(encoding="utf-8"))
42
+ except (json.JSONDecodeError, OSError):
43
+ return None
44
+
45
+ def _write_json(self, path: Path, payload: dict[str, Any]) -> None:
46
+ path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
47
+
48
+ def _read_token_data(self) -> dict[str, Any] | None:
49
+ return self._read_json(self.config.token_file)
50
+
51
+ def _write_token_data(self, token_data: dict[str, Any]) -> None:
52
+ token_data["saved_at"] = int(time.time())
53
+ self._write_json(self.config.token_file, token_data)
54
+
55
+ def _basic_auth_header(self) -> str:
56
+ raw = f"{self.config.client_id}:{self.config.client_secret}".encode("utf-8")
57
+ return base64.b64encode(raw).decode("ascii")
58
+
59
+ def _token_is_expired(self, token_data: dict[str, Any]) -> bool:
60
+ expires_in = int(token_data.get("expires_in", 0))
61
+ saved_at = int(token_data.get("saved_at", 0))
62
+ return time.time() >= saved_at + expires_in - 60
63
+
64
+ def _request_token(self, data: dict[str, Any]) -> dict[str, Any]:
65
+ response = requests.post(
66
+ f"{SPOTIFY_ACCOUNTS_BASE}/api/token",
67
+ headers={
68
+ "Authorization": f"Basic {self._basic_auth_header()}",
69
+ "Content-Type": "application/x-www-form-urlencoded",
70
+ },
71
+ data=data,
72
+ timeout=30,
73
+ )
74
+ response.raise_for_status()
75
+ return response.json()
76
+
77
+ def _refresh_token(self, refresh_token: str) -> dict[str, Any]:
78
+ new_token = self._request_token(
79
+ {
80
+ "grant_type": "refresh_token",
81
+ "refresh_token": refresh_token,
82
+ }
83
+ )
84
+ new_token["refresh_token"] = new_token.get("refresh_token", refresh_token)
85
+ self._write_token_data(new_token)
86
+ return new_token
87
+
88
+ def _get_access_token(self) -> str:
89
+ token_data = self._read_token_data()
90
+ if token_data:
91
+ if self._token_is_expired(token_data):
92
+ refresh = token_data.get("refresh_token") or self.config.env_refresh_token
93
+ if not refresh:
94
+ raise RuntimeError("Token expired and no refresh token is available.")
95
+ token_data = self._refresh_token(refresh)
96
+ token = token_data.get("access_token")
97
+ if token:
98
+ return token
99
+
100
+ if self.config.env_refresh_token:
101
+ token_data = self._refresh_token(self.config.env_refresh_token)
102
+ token = token_data.get("access_token")
103
+ if token:
104
+ return token
105
+
106
+ raise RuntimeError(
107
+ "No auth session. Set SPOTIFY_REFRESH_TOKEN in env, or run spotify_auth_url and spotify_exchange_code."
108
+ )
109
+
110
+ def _request(self, method: str, path: str, **kwargs: Any) -> dict[str, Any]:
111
+ headers = kwargs.pop("headers", {})
112
+ headers["Authorization"] = f"Bearer {self._get_access_token()}"
113
+ headers.setdefault("Content-Type", "application/json")
114
+
115
+ response = requests.request(
116
+ method,
117
+ f"{SPOTIFY_API_BASE}{path}",
118
+ headers=headers,
119
+ timeout=30,
120
+ **kwargs,
121
+ )
122
+ if response.status_code >= 400:
123
+ raise RuntimeError(f"Spotify API error {response.status_code}: {response.text}")
124
+ if response.status_code == 204:
125
+ return {"ok": True}
126
+ return response.json()
127
+
128
+ def auth_url(self, state: str | None = None) -> dict[str, Any]:
129
+ state_value = state or secrets.token_urlsafe(24)
130
+ self._write_json(self.config.state_file, {"state": state_value, "saved_at": int(time.time())})
131
+ url = (
132
+ f"{SPOTIFY_ACCOUNTS_BASE}/authorize?"
133
+ + urllib.parse.urlencode(
134
+ {
135
+ "client_id": self.config.client_id,
136
+ "response_type": "code",
137
+ "redirect_uri": self.config.redirect_uri,
138
+ "scope": self.config.scopes,
139
+ "state": state_value,
140
+ "show_dialog": "true",
141
+ }
142
+ )
143
+ )
144
  return {
145
+ "auth_url": url,
146
+ "state": state_value,
147
+ "redirect_uri": self.config.redirect_uri,
148
+ "next_step": "Open auth_url, approve app, then call spotify_exchange_code with returned code and state.",
149
  }
150
+
151
+ def exchange_code(self, code: str, state: str | None = None) -> dict[str, Any]:
152
+ state_data = self._read_json(self.config.state_file) or {}
153
+ expected = state_data.get("state")
154
+ if expected and state and state != expected:
155
+ raise RuntimeError("OAuth state mismatch.")
156
+
157
+ token_data = self._request_token(
158
+ {
159
+ "grant_type": "authorization_code",
160
+ "code": code,
161
+ "redirect_uri": self.config.redirect_uri,
162
+ }
163
+ )
164
+ self._write_token_data(token_data)
165
+ me = self.me()
166
  return {
167
+ "status": "ok",
168
+ "user_id": me.get("id"),
169
+ "display_name": me.get("display_name"),
170
+ "has_refresh_token": bool(token_data.get("refresh_token")),
 
171
  }
172
 
173
+ def me(self) -> dict[str, Any]:
174
+ me = self._request("GET", "/me")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  return {
176
+ "id": me.get("id"),
177
+ "display_name": me.get("display_name"),
178
+ "email": me.get("email"),
179
+ "product": me.get("product"),
 
180
  }
181
+
182
+ def search_tracks(self, query: str, limit: int = 5) -> dict[str, Any]:
183
+ payload = self._request(
184
+ "GET",
185
+ "/search",
186
+ params={"q": query, "type": "track", "limit": max(1, min(limit, 50))},
187
+ )
188
+ items = payload.get("tracks", {}).get("items", [])
189
  return {
190
+ "results": [
191
+ {
192
+ "id": t["id"],
193
+ "name": t["name"],
194
+ "artists": ", ".join(a["name"] for a in t.get("artists", [])),
195
+ "uri": t["uri"],
196
+ }
197
+ for t in items
198
+ ]
199
  }
200
 
201
+ def create_playlist(self, name: str, description: str = "", public: bool = False) -> dict[str, Any]:
202
+ playlist = self._request(
203
+ "POST",
204
+ "/me/playlists",
205
+ json={"name": name, "description": description, "public": public},
206
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  return {
208
+ "id": playlist.get("id"),
209
+ "name": playlist.get("name"),
210
+ "url": playlist.get("external_urls", {}).get("spotify"),
211
  }
212
 
213
+ def add_tracks(self, playlist_id: str, track_ids: list[str]) -> dict[str, Any]:
214
+ uris = [tid if tid.startswith("spotify:track:") else f"spotify:track:{tid}" for tid in track_ids]
215
+ payload = self._request("POST", f"/playlists/{playlist_id}/items", json={"uris": uris})
216
+ return {"snapshot_id": payload.get("snapshot_id"), "added": len(uris)}
217
+
218
+
219
+ def load_config() -> AuthConfig:
220
+ client_id = os.getenv("SPOTIFY_CLIENT_ID", "").strip()
221
+ client_secret = os.getenv("SPOTIFY_CLIENT_SECRET", "").strip()
222
+ redirect_uri = os.getenv("SPOTIFY_REDIRECT_URI", "").strip()
223
+ scopes = os.getenv(
224
+ "SPOTIFY_SCOPES",
225
+ "playlist-modify-public playlist-modify-private user-read-private user-read-email",
226
+ ).strip()
227
+ token_file = Path(os.getenv("SPOTIFY_TOKEN_FILE", "spotify_tokens.json"))
228
+ state_file = Path(os.getenv("SPOTIFY_STATE_FILE", "spotify_oauth_state.json"))
229
+ env_refresh_token = os.getenv("SPOTIFY_REFRESH_TOKEN", "").strip() or None
230
+
231
+ if not client_id or not client_secret or not redirect_uri:
232
+ raise RuntimeError("SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET and SPOTIFY_REDIRECT_URI are required.")
233
+
234
+ return AuthConfig(
235
+ client_id=client_id,
236
+ client_secret=client_secret,
237
+ redirect_uri=redirect_uri,
238
+ scopes=scopes,
239
+ token_file=token_file,
240
+ state_file=state_file,
241
+ env_refresh_token=env_refresh_token,
242
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
 
245
+ spotify = SpotifyClient(load_config())
246
 
247
 
248
+ def spotify_auth_url(state: str = "") -> dict[str, Any]:
249
+ return spotify.auth_url(state=state or None)
250
 
251
 
252
+ def spotify_exchange_code(code: str, state: str = "") -> dict[str, Any]:
253
+ return spotify.exchange_code(code=code, state=state or None)
254
 
255
 
256
+ def spotify_me() -> dict[str, Any]:
257
+ return spotify.me()
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
+ def search_tracks(query: str, limit: int = 5) -> dict[str, Any]:
261
+ return spotify.search_tracks(query=query, limit=limit)
 
 
 
 
 
 
 
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
+ def create_playlist(name: str, description: str = "", public: bool = False) -> dict[str, Any]:
265
+ return spotify.create_playlist(name=name, description=description, public=public)
 
 
 
 
 
 
 
 
 
 
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ def add_tracks_to_playlist(playlist_id: str, track_ids_csv: str) -> dict[str, Any]:
269
+ track_ids = [x.strip() for x in track_ids_csv.split(",") if x.strip()]
270
+ return spotify.add_tracks(playlist_id=playlist_id, track_ids=track_ids)
271
+
272
+
273
+ def create_playlist_from_search(name: str, query: str, limit: int = 10, public: bool = False) -> dict[str, Any]:
274
+ tracks = spotify.search_tracks(query=query, limit=limit).get("results", [])
275
+ if not tracks:
276
+ raise RuntimeError("No tracks found for query.")
277
+ playlist = spotify.create_playlist(name=name, description=f"Auto playlist for query: {query}", public=public)
278
+ added = spotify.add_tracks(playlist_id=playlist["id"], track_ids=[t["id"] for t in tracks])
279
+ return {"playlist": playlist, "added": added, "tracks": tracks}
280
+
281
+
282
+ auth_tab = gr.Interface(
283
+ fn=spotify_auth_url,
284
+ inputs=gr.Textbox(label="State (optional)", placeholder="optional_state"),
285
+ outputs=gr.JSON(label="Spotify Auth URL"),
286
+ title="Spotify Auth URL",
287
+ api_name="_spotify_auth_url",
288
+ )
289
+
290
+ exchange_tab = gr.Interface(
291
+ fn=spotify_exchange_code,
292
+ inputs=[
293
+ gr.Textbox(label="Authorization Code", placeholder="code from Spotify callback"),
294
+ gr.Textbox(label="State (optional)", placeholder="state from spotify_auth_url"),
295
+ ],
296
+ outputs=gr.JSON(label="Exchange Result"),
297
+ title="Spotify Exchange Code",
298
+ api_name="_spotify_exchange_code",
299
+ )
300
+
301
+ me_tab = gr.Interface(
302
+ fn=spotify_me,
303
+ inputs=[],
304
+ outputs=gr.JSON(label="Profile"),
305
+ title="Spotify Me",
306
+ api_name="_spotify_me",
307
+ )
308
+
309
+ search_tab = gr.Interface(
310
+ fn=search_tracks,
311
+ inputs=[gr.Textbox(label="Query"), gr.Slider(1, 50, value=5, step=1, label="Limit")],
312
+ outputs=gr.JSON(label="Tracks"),
313
+ title="Search Tracks",
314
+ api_name="_search_tracks",
315
+ )
316
+
317
+ create_playlist_tab = gr.Interface(
318
+ fn=create_playlist,
319
+ inputs=[
320
+ gr.Textbox(label="Playlist Name"),
321
+ gr.Textbox(label="Description", placeholder="optional", lines=2),
322
+ gr.Checkbox(label="Public", value=False),
323
+ ],
324
+ outputs=gr.JSON(label="Playlist"),
325
+ title="Create Playlist",
326
+ api_name="_create_playlist",
327
+ )
328
+
329
+ add_tracks_tab = gr.Interface(
330
+ fn=add_tracks_to_playlist,
331
+ inputs=[
332
+ gr.Textbox(label="Playlist ID"),
333
+ gr.Textbox(label="Track IDs CSV", placeholder="id1,id2,id3 or spotify:track:..."),
334
+ ],
335
+ outputs=gr.JSON(label="Add Tracks Result"),
336
+ title="Add Tracks",
337
+ api_name="_add_tracks_to_playlist",
338
+ )
339
+
340
+ auto_tab = gr.Interface(
341
+ fn=create_playlist_from_search,
342
+ inputs=[
343
+ gr.Textbox(label="Playlist Name"),
344
+ gr.Textbox(label="Search Query"),
345
+ gr.Slider(1, 50, value=10, step=1, label="Limit"),
346
+ gr.Checkbox(label="Public", value=False),
347
+ ],
348
+ outputs=gr.JSON(label="Result"),
349
+ title="Create Playlist From Search",
350
+ api_name="_create_playlist_from_search",
351
+ )
352
 
 
353
  demo = gr.TabbedInterface(
354
+ [auth_tab, exchange_tab, me_tab, search_tab, create_playlist_tab, add_tracks_tab, auto_tab],
355
+ ["Auth URL", "Exchange Code", "Me", "Search", "Create Playlist", "Add Tracks", "Auto Playlist"],
356
  theme=gr.themes.Base(),
357
  )
358
 
359
+ if __name__ == "__main__":
360
+ demo.launch(mcp_server=True)