notebooklm-py / api.py
ASEM12345's picture
Upload folder using huggingface_hub
052521c verified
"""
NotebookLM-Py FastAPI REST API
提供 RESTful 接口供 n8n 等工具调用
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
from contextlib import asynccontextmanager
from notebooklm import NotebookLMClient
import os
# ==================== 数据模型 ====================
# Request body for POST /api/notebooks.
class CreateNotebookRequest(BaseModel):
    # Title for the new notebook; passed straight to client.notebooks.create().
    title: str
# Request body for POST /api/notebooks/{notebook_id}/sources/url.
class AddSourceRequest(BaseModel):
    # URL to add to the notebook as a source.
    url: str
    # Forwarded as wait= to client.sources.add_url(); presumably makes the call
    # block until the source has been processed — TODO confirm against the library.
    wait: bool = True
# Request body for POST /api/notebooks/{notebook_id}/ask.
class AskRequest(BaseModel):
    # Question text forwarded to client.chat.ask().
    question: str
# Optional request body for POST /api/notebooks/{notebook_id}/generate/audio.
class GenerateAudioRequest(BaseModel):
    # Optional free-form instructions forwarded to client.artifacts.generate_audio().
    instructions: Optional[str] = None
# Optional request body for POST /api/notebooks/{notebook_id}/generate/quiz.
class GenerateQuizRequest(BaseModel):
    # Forwarded as difficulty= to client.artifacts.generate_quiz(). The value is
    # not validated here; accepted values are defined by the NotebookLM client
    # library. An explicit JSON null is passed through as None.
    difficulty: Optional[str] = "medium"
# Public representation of a notebook returned by the /api/notebooks endpoints.
class NotebookResponse(BaseModel):
    # Notebook identifier as reported by the NotebookLM client.
    id: str
    # Notebook title as reported by the NotebookLM client.
    title: str
# Public representation of a notebook source returned by the sources endpoints.
class SourceResponse(BaseModel):
    # Source identifier as reported by the NotebookLM client.
    id: str
    # Source title as reported by the NotebookLM client.
    title: str
    # Copied verbatim from the library's source object; the set of possible
    # values is defined by the NotebookLM client, not by this API.
    source_type: str
# Response body of POST /api/notebooks/{notebook_id}/ask.
class ChatResponse(BaseModel):
    # Answer text returned by client.chat.ask().
    answer: str
    # Declared for future use; ask_question currently never fills it, so it is
    # always None in responses.
    citations: Optional[List[str]] = None
# Response body of the /generate/* endpoints, which only start a generation job.
class TaskResponse(BaseModel):
    # Identifier of the generation task, taken from the status object the
    # NotebookLM client returns.
    task_id: str
    # Currently always the literal "started"; the library's own status value is
    # not propagated.
    status: str
# ==================== 客户端管理 ====================
@asynccontextmanager
async def get_client():
    """Provide a NotebookLM client for the duration of one ``async with`` block.

    A new client is built with ``NotebookLMClient.from_storage()`` on every
    call. The client's own async context is entered before yielding and exited
    when the caller's block ends, presumably releasing its session.
    """
    nlm_client = await NotebookLMClient.from_storage()
    async with nlm_client as session:
        yield session
# ==================== FastAPI 应用 ====================
# The FastAPI application. The interactive Swagger UI and ReDoc pages are
# mounted under /api so they share the prefix used by every route below.
app = FastAPI(
    title="NotebookLM-Py API",
    version="1.0.0",
    description="非官方 Google NotebookLM REST API",
    redoc_url="/api/redoc",
    docs_url="/api/docs",
)
# CORS configuration.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable. When it is unset or blank, every origin is allowed, which matches
# the previous hard-coded "*".
# Credentialed CORS is only enabled for an explicit origin list. A wildcard
# origin must not be combined with allow_credentials=True: browsers reject
# "Access-Control-Allow-Origin: *" on credentialed requests. No endpoint here
# reads cookies or auth headers anyway.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==================== 笔记本 API ====================
@app.get("/api/notebooks", response_model=List[NotebookResponse], tags=["Notebooks"])
async def list_notebooks():
    """列出所有笔记本"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Returns every notebook from client.notebooks.list() projected onto
    # {id, title}. Any failure is reported as HTTP 500 with the error text.
    try:
        async with get_client() as nlm:
            found = await nlm.notebooks.list()
            payload = [NotebookResponse(id=item.id, title=item.title) for item in found]
            return payload
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/notebooks", response_model=NotebookResponse, tags=["Notebooks"])
async def create_notebook(request: CreateNotebookRequest):
    """创建新笔记本"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Creates a notebook titled request.title and echoes back its id and title.
    # Any failure is reported as HTTP 500 with the error text.
    try:
        async with get_client() as nlm:
            created = await nlm.notebooks.create(request.title)
            response = NotebookResponse(id=created.id, title=created.title)
            return response
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.get("/api/notebooks/{notebook_id}", response_model=NotebookResponse, tags=["Notebooks"])
async def get_notebook(notebook_id: str):
    """获取笔记本详情"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Looks up one notebook by id and returns its id and title. Every failure,
    # including whatever the library raises for an unknown id, is reported as
    # HTTP 500 with the error text.
    try:
        async with get_client() as nlm:
            notebook = await nlm.notebooks.get(notebook_id)
            response = NotebookResponse(id=notebook.id, title=notebook.title)
            return response
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.delete("/api/notebooks/{notebook_id}", tags=["Notebooks"])
async def delete_notebook(notebook_id: str):
    """删除笔记本"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Deletes the notebook through the client and returns a fixed confirmation
    # message. Any failure is reported as HTTP 500 with the error text.
    try:
        async with get_client() as nlm:
            await nlm.notebooks.delete(notebook_id)
            confirmation = {"message": "Notebook deleted successfully"}
            return confirmation
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
# ==================== 资源 API ====================
@app.get("/api/notebooks/{notebook_id}/sources", response_model=List[SourceResponse], tags=["Sources"])
async def list_sources(notebook_id: str):
    """列出笔记本的所有资源"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Returns each source of the notebook as {id, title, source_type}. Any
    # failure is reported as HTTP 500 with the error text.
    try:
        async with get_client() as nlm:
            found = await nlm.sources.list(notebook_id)
            payload = [
                SourceResponse(id=src.id, title=src.title, source_type=src.source_type)
                for src in found
            ]
            return payload
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/notebooks/{notebook_id}/sources/url", response_model=SourceResponse, tags=["Sources"])
async def add_url_source(notebook_id: str, request: AddSourceRequest):
    """添加 URL 资源"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Adds request.url as a source, forwarding request.wait unchanged, and
    # returns the resulting source. Any failure is reported as HTTP 500 with
    # the error text.
    try:
        async with get_client() as nlm:
            added = await nlm.sources.add_url(notebook_id, request.url, wait=request.wait)
            response = SourceResponse(
                id=added.id,
                title=added.title,
                source_type=added.source_type,
            )
            return response
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
# ==================== 对话 API ====================
@app.post("/api/notebooks/{notebook_id}/ask", response_model=ChatResponse, tags=["Chat"])
async def ask_question(notebook_id: str, request: AskRequest):
    """向笔记本提问"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Sends the question to client.chat.ask() and returns only the answer text.
    # Citations are not populated. Any failure is reported as HTTP 500 with the
    # error text.
    try:
        async with get_client() as nlm:
            reply = await nlm.chat.ask(notebook_id, request.question)
            response = ChatResponse(answer=reply.answer)
            return response
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
# ==================== 生成内容 API ====================
@app.post("/api/notebooks/{notebook_id}/generate/audio", response_model=TaskResponse, tags=["Generate"])
async def generate_audio(notebook_id: str, request: GenerateAudioRequest = None):
    """生成音频播客"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Starts audio generation and returns the task id with status "started".
    # The request body is optional; without one no instructions are sent.
    # Any failure is reported as HTTP 500 with the error text.
    try:
        extra_instructions = None if not request else request.instructions
        async with get_client() as nlm:
            job = await nlm.artifacts.generate_audio(notebook_id, instructions=extra_instructions)
            return TaskResponse(task_id=job.task_id, status="started")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/notebooks/{notebook_id}/generate/quiz", response_model=TaskResponse, tags=["Generate"])
async def generate_quiz(notebook_id: str, request: GenerateQuizRequest = None):
    """生成测验"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Starts quiz generation and returns the task id with status "started".
    # Without a request body the difficulty falls back to "medium".
    # Any failure is reported as HTTP 500 with the error text.
    try:
        level = "medium" if not request else request.difficulty
        async with get_client() as nlm:
            job = await nlm.artifacts.generate_quiz(notebook_id, difficulty=level)
            return TaskResponse(task_id=job.task_id, status="started")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/notebooks/{notebook_id}/generate/flashcards", response_model=TaskResponse, tags=["Generate"])
async def generate_flashcards(notebook_id: str):
    """生成闪卡"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Starts flashcard generation and returns the task id with status "started".
    # Any failure is reported as HTTP 500 with the error text.
    try:
        async with get_client() as nlm:
            job = await nlm.artifacts.generate_flashcards(notebook_id)
            response = TaskResponse(task_id=job.task_id, status="started")
            return response
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
# ==================== 健康检查 ====================
@app.get("/api/health", tags=["System"])
async def health_check():
    """健康检查"""
    # The docstring is kept verbatim: FastAPI publishes it as the OpenAPI description.
    # Liveness probe. It never contacts NotebookLM, so it only reflects whether
    # the web server itself is answering.
    health = {"status": "ok"}
    return health
if __name__ == "__main__":
    import uvicorn

    # Run the API directly with uvicorn. The bind address and port can be
    # overridden with the HOST / PORT environment variables (e.g. on container
    # platforms); the defaults match the previously hard-coded 0.0.0.0:8000.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )