diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..24a8e87939aa53cdd833f6be7610cb4972e063ad --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..5d71c70baa4899f67a45d87e7b7191d382b78292 --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,80 @@ +name: Build and Publish Docker Image + +on: + release: + types: [published] + pull_request: + branches: [main] + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + DOCKERHUB_REGISTRY: docker.io + DOCKERHUB_IMAGE_NAME: variantconst/openwebui-monitor + +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + # 检出代码 + - name: Checkout repository + uses: actions/checkout@v4 + + # 设置 QEMU(支持多平台) + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + # 设置 Docker Buildx(支持多平台构建) + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + # 登录到 GitHub Container Registry (GHCR) + - name: Log into GHCR + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # 登录到 Docker Hub + - name: Log into Docker Hub + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ${{ env.DOCKERHUB_REGISTRY }} + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # 提取 Docker 元数据 + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: | + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + ${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha + type=raw,value=latest,enable={{is_default_branch}} + + # 构建并推送到 GHCR 和 Docker Hub + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + provenance: false diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..f969a39aafba5e3a1d12e92743d75629bf0c3abe --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +v0.3.7 \ No newline at end of file diff --git a/app/api/config/route.ts b/app/api/config/route.ts deleted file mode 100644 index abd1f80c3476b4d2b758130b8d65a45e570e2285..0000000000000000000000000000000000000000 --- a/app/api/config/route.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { NextResponse } from "next/server"; -import { headers } from "next/headers"; - -export async function GET() { - const headersList = headers(); - const token = headersList.get("authorization")?.split(" ")[1]; - const expectedToken = process.env.ACCESS_TOKEN; - - if (!token || token !== expectedToken) { - return NextResponse.json( - { - apiKey: "Unauthorized", - status: 401, - }, - { status: 401 } - ); - } - - return NextResponse.json({ - apiKey: process.env.API_KEY || "Unconfigured", - status: 200, - }); -} diff --git a/app/api/init/route.ts b/app/api/init/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..1a96963f4ae4bff9f0474b1da7874121e1b9e325 --- /dev/null +++ b/app/api/init/route.ts @@ -0,0 +1,22 @@ +import { NextResponse } from "next/server"; +import { initDatabase } from "@/lib/db/client"; + +let initialized = false; + +export async function GET() { + if (!initialized) { + try { + await initDatabase(); + initialized = true; + return NextResponse.json({ success: true, message: "数据库初始化成功" }); + } catch (error) { + console.error("数据库初始化失败:", error); + return NextResponse.json( + { success: false, error: "数据库初始化失败" }, + { status: 500 } + ); + } + } else { + return NextResponse.json({ success: true, message: "数据库已初始化" }); + } +} diff --git a/app/api/config/key/route.ts b/app/api/v1/config/key/route.ts similarity index 58% rename from app/api/config/key/route.ts rename to app/api/v1/config/key/route.ts index ba9684c6423bf1dba735b86bff0fc5a318c822ba..3d8c3d43b204e30ad78fd45cb5f8e106b074af9c 100644 --- a/app/api/config/key/route.ts +++ b/app/api/v1/config/key/route.ts @@ -1,7 +1,12 @@ import { NextResponse } from "next/server"; -import { cookies } from "next/headers"; +import { verifyApiToken } from "@/lib/auth"; + +export async function GET(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } -export async function GET() { const apiKey = process.env.API_KEY; if (!apiKey) { diff --git a/app/api/v1/config/route.ts b/app/api/v1/config/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..73bd86b7862d20f493a88243e02fb0d01292f1f7 --- /dev/null +++ b/app/api/v1/config/route.ts @@ -0,0 +1,14 @@ +import { NextResponse } from "next/server"; +import { verifyApiToken } from "@/lib/auth"; + +export async function GET(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + + return NextResponse.json({ + apiKey: process.env.API_KEY || "Unconfigured", + status: 200, + }); +} diff --git a/app/api/v1/inlet/route.ts b/app/api/v1/inlet/route.ts index 9352ea28f726e0451c2a1cef680a4bd004527c8b..4f0694edefdd2eecea9215372ef460f732816f3a 100644 --- a/app/api/v1/inlet/route.ts +++ b/app/api/v1/inlet/route.ts @@ -9,7 +9,6 @@ export async function POST(req: Request) { const user = await getOrCreateUser(data.user); const modelId = data.body?.model; - // 如果用户被拉黑,返回余额为 -1 if (user.deleted) { return NextResponse.json({ success: true, @@ -18,10 +17,8 @@ export async function POST(req: Request) { }); } - // 获取预扣费金额 const inletCost = getModelInletCost(modelId); - // 预扣费 if (inletCost > 0) { const userResult = await query( `UPDATE users diff --git a/app/api/v1/models/price/route.ts b/app/api/v1/models/price/route.ts index b7056b258fb36a58410c485a5cad64877c4f39a5..82fa461f5d79c2efdc0038555694d6578ce8200b 100644 --- a/app/api/v1/models/price/route.ts +++ b/app/api/v1/models/price/route.ts @@ -1,5 +1,6 @@ import { NextRequest, NextResponse } from "next/server"; -import { updateModelPrice } from "@/lib/db"; +import { updateModelPrice } from "@/lib/db/client"; +import { verifyApiToken } from "@/lib/auth"; interface PriceUpdate { id: string; @@ -9,11 +10,15 @@ interface PriceUpdate { } export async function POST(request: NextRequest) { + const authError = verifyApiToken(request); + if (authError) { + return authError; + } + try { const data = await request.json(); console.log("Raw data received:", data); - // 从对象中提取模型数组 const updates = data.updates || data; if (!Array.isArray(updates)) { console.error("Invalid data format - expected array:", updates); @@ -23,7 +28,6 @@ export async function POST(request: NextRequest) { ); } - // 验证并转换数据格式 const validUpdates = updates .map((update: any) => ({ id: update.id, @@ -52,7 +56,6 @@ export async function POST(request: NextRequest) { `Successfully verified price updating requests of ${validUpdates.length} models` ); - // 执行批量更新并收集结果 const results = await Promise.all( validUpdates.map(async (update: PriceUpdate) => { try { diff --git a/app/api/v1/models/route.ts b/app/api/v1/models/route.ts index dfd87bf10cb5f1cd6230f4838ea5a01a636604c7..3292d565a90ff598662a637e2f747db9f87c93d6 100644 --- a/app/api/v1/models/route.ts +++ b/app/api/v1/models/route.ts @@ -1,5 +1,6 @@ import { NextResponse } from "next/server"; -import { ensureTablesExist, getOrCreateModelPrices } from "@/lib/db"; +import { ensureTablesExist, getOrCreateModelPrices } from "@/lib/db/client"; +import { verifyApiToken } from "@/lib/auth"; interface ModelInfo { id: string; @@ -21,9 +22,13 @@ interface ModelResponse { }[]; } -export async function GET() { +export async function GET(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + try { - // Ensure database is initialized await ensureTablesExist(); const domain = process.env.OPENWEBUI_DOMAIN; @@ -31,7 +36,6 @@ export async function GET() { throw new Error("OPENWEBUI_DOMAIN environment variable is not set."); } - // Normalize API URL const apiUrl = domain.replace(/\/+$/, "") + "/api/models"; const response = await fetch(apiUrl, { @@ -47,9 +51,7 @@ export async function GET() { throw new Error(`Failed to fetch models: ${response.status}`); } - // Get response text for debugging const responseText = await response.text(); - // console.log("API response:", responseText); let data: ModelResponse; try { @@ -59,20 +61,25 @@ export async function GET() { throw new Error("Invalid JSON response from API"); } - console.log("data:", data); - if (!data || !Array.isArray(data.data)) { console.error("Unexpected API response structure:", data); throw new Error("Unexpected API response structure"); } - // Get price information for all models + const apiModelsMap = new Map(); + data.data.forEach((item) => { + apiModelsMap.set(String(item.id), { + name: String(item.name), + base_model_id: item.info?.base_model_id || "", + imageUrl: item.info?.meta?.profile_image_url || "/static/favicon.png", + system_prompt: item.info?.params?.system || "", + }); + }); + const modelsWithPrices = await getOrCreateModelPrices( data.data.map((item) => { - // 处理形如 gemini_search.gemini-2.0-flash 的派生模型ID let baseModelId = item.info?.base_model_id; - // 如果没有明确的base_model_id,尝试从ID中提取 if (!baseModelId && item.id) { const idParts = String(item.id).split("."); if (idParts.length > 1) { @@ -88,30 +95,46 @@ export async function GET() { }) ); - const validModels = data.data.map((item, index) => { - // 处理形如 gemini_search.gemini-2.0-flash 的派生模型ID - let baseModelId = item.info?.base_model_id || ""; + const dbModelsMap = new Map(); + modelsWithPrices.forEach((model) => { + dbModelsMap.set(model.id, { + input_price: model.input_price, + output_price: model.output_price, + per_msg_price: model.per_msg_price, + updated_at: model.updated_at, + }); + }); + + const validModels = Array.from(apiModelsMap.entries()).map( + ([id, apiModel]) => { + const dbModel = dbModelsMap.get(id) || { + input_price: 60, + output_price: 60, + per_msg_price: -1, + updated_at: new Date(), + }; - // 如果没有明确的base_model_id,尝试从ID中提取 - if (!baseModelId && item.id) { - const idParts = String(item.id).split("."); - if (idParts.length > 1) { - baseModelId = idParts[idParts.length - 1]; + let baseModelId = apiModel.base_model_id; + if (!baseModelId && id) { + const idParts = String(id).split("."); + if (idParts.length > 1) { + baseModelId = idParts[idParts.length - 1]; + } } - } - return { - id: modelsWithPrices[index].id, - base_model_id: baseModelId, - name: modelsWithPrices[index].name, - imageUrl: item.info?.meta?.profile_image_url || "/static/favicon.png", - system_prompt: item.info?.params?.system || "", - input_price: modelsWithPrices[index].input_price, - output_price: modelsWithPrices[index].output_price, - per_msg_price: modelsWithPrices[index].per_msg_price, - updated_at: modelsWithPrices[index].updated_at, - }; - }); + return { + id: id, + base_model_id: baseModelId, + name: apiModel.name, + imageUrl: apiModel.imageUrl, + system_prompt: apiModel.system_prompt, + input_price: dbModel.input_price, + output_price: dbModel.output_price, + per_msg_price: dbModel.per_msg_price, + updated_at: dbModel.updated_at, + }; + } + ); return NextResponse.json(validModels); } catch (error) { @@ -126,8 +149,12 @@ export async function GET() { } } -// Add inlet endpoint export async function POST(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + const data = await req.json(); return new Response("Inlet placeholder response", { @@ -135,10 +162,13 @@ export async function POST(req: Request) { }); } -// Add outlet endpoint export async function PUT(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + const data = await req.json(); - // console.log("Outlet received:", JSON.stringify(data, null, 2)); return new Response("Outlet placeholder response", { headers: { "Content-Type": "application/json" }, diff --git a/app/api/v1/models/sync-all-prices/route.ts b/app/api/v1/models/sync-all-prices/route.ts index 7bf9d4abdd70d427de516477f7e090cfc3548194..447402e798831e9693c406b3930b9dcb1e5b0544 100644 --- a/app/api/v1/models/sync-all-prices/route.ts +++ b/app/api/v1/models/sync-all-prices/route.ts @@ -1,11 +1,16 @@ import { NextRequest, NextResponse } from "next/server"; -import { pool } from "@/lib/db"; +import { pool } from "@/lib/db/client"; +import { verifyApiToken } from "@/lib/auth"; export async function POST(request: NextRequest) { + const authError = verifyApiToken(request); + if (authError) { + return authError; + } + try { const client = await pool.connect(); try { - // 1. 获取所有有效的派生模型(base_model_id 存在且在数据库中有对应记录) const derivedModelsResult = await client.query(` SELECT d.id, d.name, d.base_model_id FROM model_prices d @@ -24,10 +29,8 @@ export async function POST(request: NextRequest) { const derivedModels = derivedModelsResult.rows; const syncResults = []; - // 2. 为每个派生模型同步价格 for (const derivedModel of derivedModels) { try { - // 获取上游模型价格 const baseModelResult = await client.query( `SELECT input_price, output_price, per_msg_price FROM model_prices WHERE id = $1`, [derivedModel.base_model_id] @@ -45,7 +48,6 @@ export async function POST(request: NextRequest) { const baseModel = baseModelResult.rows[0]; - // 更新派生模型价格 const updateResult = await client.query( `UPDATE model_prices SET diff --git a/app/api/v1/models/sync-price/route.ts b/app/api/v1/models/sync-price/route.ts index 031456a76b4a7746a59df04679ed3a474e7757da..767393c1b073b6d86444ddc38a493465e7024b85 100644 --- a/app/api/v1/models/sync-price/route.ts +++ b/app/api/v1/models/sync-price/route.ts @@ -1,7 +1,13 @@ import { NextRequest, NextResponse } from "next/server"; -import { pool } from "@/lib/db"; +import { pool } from "@/lib/db/client"; +import { verifyApiToken } from "@/lib/auth"; export async function POST(request: NextRequest) { + const authError = verifyApiToken(request); + if (authError) { + return authError; + } + try { const data = await request.json(); const { modelId } = data; @@ -15,7 +21,6 @@ export async function POST(request: NextRequest) { const client = await pool.connect(); try { - // 1. 获取派生模型信息 const derivedModelResult = await client.query( `SELECT id, name, base_model_id FROM model_prices WHERE id = $1`, [modelId] @@ -28,13 +33,11 @@ export async function POST(request: NextRequest) { const derivedModel = derivedModelResult.rows[0]; let baseModelId = derivedModel.base_model_id; - // 如果数据库中没有base_model_id,尝试从ID中提取 if (!baseModelId) { const idParts = modelId.split("."); if (idParts.length > 1) { baseModelId = idParts[idParts.length - 1]; - // 更新数据库中的base_model_id await client.query( `UPDATE model_prices SET base_model_id = $2 WHERE id = $1`, [modelId, baseModelId] @@ -49,7 +52,6 @@ export async function POST(request: NextRequest) { ); } - // 2. 获取上游模型价格 const baseModelResult = await client.query( `SELECT input_price, output_price, per_msg_price FROM model_prices WHERE id = $1`, [baseModelId] @@ -64,7 +66,6 @@ export async function POST(request: NextRequest) { const baseModel = baseModelResult.rows[0]; - // 3. 更新派生模型价格 const updateResult = await client.query( `UPDATE model_prices SET diff --git a/app/api/v1/models/test/route.ts b/app/api/v1/models/test/route.ts index 5a6e89a5f3f72917fd3587236b50894908f4207c..4536a36aafc603fa41b28a6b073caf5aca531e3a 100644 --- a/app/api/v1/models/test/route.ts +++ b/app/api/v1/models/test/route.ts @@ -1,6 +1,12 @@ import { NextResponse } from "next/server"; +import { verifyApiToken } from "@/lib/auth"; export async function POST(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + try { const { modelId } = await req.json(); diff --git a/app/api/v1/outlet/route.ts b/app/api/v1/outlet/route.ts index 626f25395f6418b61d9a8fcb0dfebe0462713042..ec0a07a2a9b151b435a18e0de39e210a50671fa1 100644 --- a/app/api/v1/outlet/route.ts +++ b/app/api/v1/outlet/route.ts @@ -34,7 +34,6 @@ async function getModelPrice(modelId: string): Promise { return result.rows[0]; } - // If no price is found in the database, use the default price const defaultInputPrice = parseFloat( process.env.DEFAULT_MODEL_INPUT_PRICE || "60" ); @@ -42,7 +41,6 @@ async function getModelPrice(modelId: string): Promise { process.env.DEFAULT_MODEL_OUTPUT_PRICE || "60" ); - // Verify that the default price is a valid non-negative number if ( isNaN(defaultInputPrice) || defaultInputPrice < 0 || @@ -57,7 +55,7 @@ async function getModelPrice(modelId: string): Promise { name: modelId, input_price: defaultInputPrice, output_price: defaultOutputPrice, - per_msg_price: -1, // Default to token-based pricing + per_msg_price: -1, }; } @@ -66,7 +64,6 @@ export async function POST(req: Request) { let pgClient: DbClient | null = null; try { - // Get a dedicated transaction client if (isVercel) { pgClient = client; } else { @@ -74,21 +71,18 @@ export async function POST(req: Request) { } const data = await req.json(); - console.log("请求数据:", JSON.stringify(data, null, 2)); + console.log("Request data:", JSON.stringify(data, null, 2)); const modelId = data.body.model; const userId = data.user.id; const userName = data.user.name || "Unknown User"; - // Start a transaction await query("BEGIN"); - // Get model price const modelPrice = await getModelPrice(modelId); if (!modelPrice) { throw new Error(`Fail to fetch price info of model ${modelId}`); } - // Calculate tokens const lastMessage = data.body.messages[data.body.messages.length - 1]; let inputTokens: number; @@ -109,32 +103,25 @@ export async function POST(req: Request) { inputTokens = totalTokens - outputTokens; } - // Calculate total cost let totalCost: number; if (outputTokens === 0) { - // If output tokens are 0, no charge totalCost = 0; console.log("No charge for zero output tokens"); } else if (modelPrice.per_msg_price >= 0) { - // If fixed pricing is set, use it directly totalCost = Number(modelPrice.per_msg_price); console.log( `Using fixed pricing: ${totalCost} (${modelPrice.per_msg_price} per message)` ); } else { - // Otherwise, calculate price by token quantity const inputCost = (inputTokens / 1_000_000) * modelPrice.input_price; const outputCost = (outputTokens / 1_000_000) * modelPrice.output_price; totalCost = inputCost + outputCost; } - // Get the pre-deducted cost when getting inlet const inletCost = getModelInletCost(modelId); - // The actual cost to be deducted = total cost - pre-deducted cost const actualCost = totalCost - inletCost; - // Get and update user balance const userResult = await query( `UPDATE users SET balance = LEAST( @@ -156,7 +143,6 @@ export async function POST(req: Request) { throw new Error("Balance exceeds maximum allowed value"); } - // Record usage await query( `INSERT INTO user_usage_records ( user_id, nickname, model_name, @@ -208,7 +194,6 @@ export async function POST(req: Request) { { status: 500 } ); } finally { - // Only release connection in non-Vercel environment if (!isVercel && pgClient && "release" in pgClient) { pgClient.release(); } diff --git a/app/api/v1/panel/database/export/route.ts b/app/api/v1/panel/database/export/route.ts index f8d758fd6bd4e08e05980139f48f2021e348422a..746225f8b57cb7d9d6e7b2f8c1bfb5878ae6e4d4 100644 --- a/app/api/v1/panel/database/export/route.ts +++ b/app/api/v1/panel/database/export/route.ts @@ -1,24 +1,18 @@ -import { pool } from "@/lib/db"; +import { query } from "@/lib/db/client"; import { NextResponse } from "next/server"; -import { PoolClient } from "pg"; +import { verifyApiToken } from "@/lib/auth"; -export async function GET() { - let client: PoolClient | null = null; +export async function GET(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } try { - // 获取数据库连接 - client = await pool.connect(); - - // 获取所有表的数据 - const users = await client.query("SELECT * FROM users ORDER BY id"); - const modelPrices = await client.query( - "SELECT * FROM model_prices ORDER BY id" - ); - const records = await client.query( - "SELECT * FROM user_usage_records ORDER BY id" - ); + const users = await query("SELECT * FROM users ORDER BY id"); + const modelPrices = await query("SELECT * FROM model_prices ORDER BY id"); + const records = await query("SELECT * FROM user_usage_records ORDER BY id"); - // 构建导出数据结构 const exportData = { version: "1.0", timestamp: new Date().toISOString(), @@ -29,7 +23,6 @@ export async function GET() { }, }; - // 设置响应头 const headers = new Headers(); headers.set("Content-Type", "application/json"); headers.set( @@ -48,9 +41,5 @@ export async function GET() { { error: "Fail to export database" }, { status: 500 } ); - } finally { - if (client) { - client.release(); - } } } diff --git a/app/api/v1/panel/database/import/route.ts b/app/api/v1/panel/database/import/route.ts index 9f08f6a8eb0b1479b2e8e9b8bbc3040f071e40d4..d53734182ff47131090476e4b1af5e2969ee40ef 100644 --- a/app/api/v1/panel/database/import/route.ts +++ b/app/api/v1/panel/database/import/route.ts @@ -1,34 +1,30 @@ -import { pool } from "@/lib/db"; +import { query } from "@/lib/db/client"; import { NextResponse } from "next/server"; -import { PoolClient } from "pg"; +import { verifyApiToken } from "@/lib/auth"; export async function POST(req: Request) { - let client: PoolClient | null = null; + const authError = verifyApiToken(req); + if (authError) { + return authError; + } try { const data = await req.json(); - // 验证数据格式 if (!data.version || !data.data) { throw new Error("Invalid import data format"); } - // 获取数据库连接 - client = await pool.connect(); - - // 开启事务 - await client.query("BEGIN"); - try { - // 清空现有数据 - await client.query("TRUNCATE TABLE user_usage_records CASCADE"); - await client.query("TRUNCATE TABLE model_prices CASCADE"); - await client.query("TRUNCATE TABLE users CASCADE"); + await query("BEGIN"); + + await query("TRUNCATE TABLE user_usage_records CASCADE"); + await query("TRUNCATE TABLE model_prices CASCADE"); + await query("TRUNCATE TABLE users CASCADE"); - // 导入用户数据 if (data.data.users?.length) { for (const user of data.data.users) { - await client.query( + await query( `INSERT INTO users (id, email, name, role, balance) VALUES ($1, $2, $3, $4, $5)`, [user.id, user.email, user.name, user.role, user.balance] @@ -36,10 +32,9 @@ export async function POST(req: Request) { } } - // 导入模型价格 if (data.data.model_prices?.length) { for (const price of data.data.model_prices) { - await client.query( + await query( `INSERT INTO model_prices (id, name, input_price, output_price) VALUES ($1, $2, $3, $4)`, [price.id, price.name, price.input_price, price.output_price] @@ -47,10 +42,9 @@ export async function POST(req: Request) { } } - // 导入使用记录 if (data.data.user_usage_records?.length) { for (const record of data.data.user_usage_records) { - await client.query( + await query( `INSERT INTO user_usage_records ( user_id, nickname, use_time, model_name, input_tokens, output_tokens, cost, balance_after @@ -69,14 +63,14 @@ export async function POST(req: Request) { } } - await client.query("COMMIT"); + await query("COMMIT"); return NextResponse.json({ success: true, message: "Data import successful", }); } catch (error) { - await client.query("ROLLBACK"); + await query("ROLLBACK"); throw error; } } catch (error) { @@ -89,9 +83,5 @@ export async function POST(req: Request) { }, { status: 500 } ); - } finally { - if (client) { - client.release(); - } } } diff --git a/app/api/v1/panel/records/export/route.ts b/app/api/v1/panel/records/export/route.ts index 622608b129ff3a316d178870d5fa58b3ac7a8cff..f3795099f60030896e586cfed4966d41f5f14909 100644 --- a/app/api/v1/panel/records/export/route.ts +++ b/app/api/v1/panel/records/export/route.ts @@ -1,13 +1,15 @@ -import { pool } from "@/lib/db"; +import { query } from "@/lib/db/client"; import { NextResponse } from "next/server"; -import { PoolClient } from "pg"; +import { verifyApiToken } from "@/lib/auth"; -export async function GET() { - let client: PoolClient | null = null; - try { - client = await pool.connect(); +export async function GET(req: Request) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } - const records = await client.query(` + try { + const records = await query(` SELECT nickname, use_time, @@ -20,7 +22,6 @@ export async function GET() { ORDER BY use_time DESC `); - // 生成 CSV 内容 const csvHeaders = [ "User", "Time", @@ -45,7 +46,6 @@ export async function GET() { ...rows.map((row) => row.join(",")), ].join("\n"); - // 设置响应头 const responseHeaders = new Headers(); responseHeaders.set("Content-Type", "text/csv; charset=utf-8"); responseHeaders.set( @@ -62,9 +62,5 @@ export async function GET() { { error: "Fail to export records" }, { status: 500 } ); - } finally { - if (client) { - client.release(); - } } } diff --git a/app/api/v1/panel/records/route.ts b/app/api/v1/panel/records/route.ts index e931d49412be7aaf83694c3c11f55c9a05c5ea54..bf3d2ababfc7cceff14cefab5f6d76ec42d2348f 100644 --- a/app/api/v1/panel/records/route.ts +++ b/app/api/v1/panel/records/route.ts @@ -1,9 +1,13 @@ -import { pool } from "@/lib/db"; +import { query } from "@/lib/db/client"; import { NextResponse } from "next/server"; -import { PoolClient } from "pg"; +import { verifyApiToken } from "@/lib/auth"; export async function GET(req: Request) { - let client: PoolClient | null = null; + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + try { const { searchParams } = new URL(req.url); const page = parseInt(searchParams.get("page") || "1"); @@ -12,10 +16,9 @@ export async function GET(req: Request) { const sortOrder = searchParams.get("sortOrder"); const users = searchParams.get("users")?.split(",") || []; const models = searchParams.get("models")?.split(",") || []; + const startDate = searchParams.get("startDate"); + const endDate = searchParams.get("endDate"); - client = await pool.connect(); - - // 构建查询条件 const conditions = []; const params = []; let paramIndex = 1; @@ -32,23 +35,29 @@ export async function GET(req: Request) { paramIndex++; } + if (startDate && endDate) { + conditions.push( + `use_time >= $${paramIndex} AND use_time <= $${paramIndex + 1}` + ); + params.push(startDate); + params.push(endDate); + paramIndex += 2; + } + const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(" AND ")}` : ""; - // 构建排序 const orderClause = sortField ? `ORDER BY ${sortField} ${sortOrder === "descend" ? "DESC" : "ASC"}` : "ORDER BY use_time DESC"; - // 获取总记录数 const countQuery = ` SELECT COUNT(*) FROM user_usage_records ${whereClause} `; - const countResult = await client.query(countQuery, params); + const countResult = await query(countQuery, params); - // 获取分页数据 const offset = (page - 1) * pageSize; const dataQuery = ` SELECT @@ -67,7 +76,7 @@ export async function GET(req: Request) { `; const dataParams = [...params, pageSize, offset]; - const records = await client.query(dataQuery, dataParams); + const records = await query(dataQuery, dataParams); const total = parseInt(countResult.rows[0].count); @@ -81,9 +90,5 @@ export async function GET(req: Request) { { error: "Fail to fetch usage records" }, { status: 500 } ); - } finally { - if (client) { - client.release(); - } } } diff --git a/app/api/v1/panel/usage/route.ts b/app/api/v1/panel/usage/route.ts index 7a3b47533343c204c7cd3cbf5ac65067c44a51b8..40ca73411b5976d2f44c9f5fcc430dce890aff2c 100644 --- a/app/api/v1/panel/usage/route.ts +++ b/app/api/v1/panel/usage/route.ts @@ -1,12 +1,20 @@ import { NextResponse } from "next/server"; -import { pool } from "@/lib/db"; +import { query } from "@/lib/db/client"; +import { verifyApiToken } from "@/lib/auth"; export async function GET(request: Request) { + const authError = verifyApiToken(request); + if (authError) { + return authError; + } + try { const { searchParams } = new URL(request.url); const startTime = searchParams.get("startTime"); const endTime = searchParams.get("endTime"); + console.log("Query params:", [startTime, endTime]); + const timeFilter = startTime && endTime ? `WHERE use_time >= $1 AND use_time <= $2` : ""; @@ -14,7 +22,7 @@ export async function GET(request: Request) { const [modelResult, userResult, timeRangeResult, statsResult] = await Promise.all([ - pool.query( + query( ` SELECT model_name, @@ -27,7 +35,7 @@ export async function GET(request: Request) { `, params ), - pool.query( + query( ` SELECT nickname, @@ -40,13 +48,13 @@ export async function GET(request: Request) { `, params ), - pool.query(` + query(` SELECT MIN(use_time) as min_time, MAX(use_time) as max_time FROM user_usage_records `), - pool.query( + query( ` SELECT COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens, @@ -82,6 +90,9 @@ export async function GET(request: Request) { return NextResponse.json(formattedData); } catch (error) { console.error("Fail to fetch usage records:", error); + if (error instanceof Error) { + console.error("[DB Query Error]", error); + } return NextResponse.json( { error: "Fail to fetch usage records" }, { status: 500 } diff --git a/app/api/users/[id]/balance/route.ts b/app/api/v1/users/[id]/balance/route.ts similarity index 75% rename from app/api/users/[id]/balance/route.ts rename to app/api/v1/users/[id]/balance/route.ts index da2fe1b4e78e4a8a5ed2c5e2ad4ab02edea83de8..7866e84f61c6f314efcc2c40152c6f7f43a02320 100644 --- a/app/api/users/[id]/balance/route.ts +++ b/app/api/v1/users/[id]/balance/route.ts @@ -1,17 +1,25 @@ import { query } from "@/lib/db/client"; import { NextResponse } from "next/server"; +import { verifyApiToken } from "@/lib/auth"; export async function PUT( req: Request, { params }: { params: { id: string } } ) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + try { const { balance } = await req.json(); const userId = params.id; + console.log(`Updating balance for user ${userId} to ${balance}`); + if (typeof balance !== "number") { return NextResponse.json( - { error: "Balance must be positive" }, + { error: "Balance must be a number" }, { status: 400 } ); } @@ -24,6 +32,8 @@ export async function PUT( [balance, userId] ); + console.log(`Update result:`, result); + if (result.rows.length === 0) { return NextResponse.json( { error: "User does not exist" }, diff --git a/app/api/users/[id]/route.ts b/app/api/v1/users/[id]/route.ts similarity index 84% rename from app/api/users/[id]/route.ts rename to app/api/v1/users/[id]/route.ts index f4697dc5db1986b516b2d52892943a693b4218f4..69c95f62a177e82c3c3254bfda3cd567f45c0e31 100644 --- a/app/api/users/[id]/route.ts +++ b/app/api/v1/users/[id]/route.ts @@ -1,11 +1,17 @@ import { NextRequest, NextResponse } from "next/server"; import { deleteUser } from "@/lib/db/users"; import { query } from "@/lib/db/client"; +import { verifyApiToken } from "@/lib/auth"; export async function DELETE( req: NextRequest, { params }: { params: { id: string } } ) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + try { await deleteUser(params.id); return NextResponse.json({ success: true }); @@ -19,6 +25,11 @@ export async function PATCH( req: NextRequest, { params }: { params: { id: string } } ) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + try { const { deleted } = await req.json(); diff --git a/app/api/users/route.ts b/app/api/v1/users/route.ts similarity index 93% rename from app/api/users/route.ts rename to app/api/v1/users/route.ts index f6a30cf69e265534bae26a8c67477dea0e3d3d54..4187890fc2f8c4d3b923a5c1841ed318f6727129 100644 --- a/app/api/users/route.ts +++ b/app/api/v1/users/route.ts @@ -1,10 +1,15 @@ import { NextRequest, NextResponse } from "next/server"; import { query } from "@/lib/db/client"; import { ensureUserTableExists } from "@/lib/db/users"; +import { verifyApiToken } from "@/lib/auth"; export async function GET(req: NextRequest) { + const authError = verifyApiToken(req); + if (authError) { + return authError; + } + try { - // 确保表结构正确 await ensureUserTableExists(); const { searchParams } = new URL(req.url); @@ -15,7 +20,6 @@ export async function GET(req: NextRequest) { const search = searchParams.get("search"); const deleted = searchParams.get("deleted") === "true"; - // 构建查询条件 const conditions = [`deleted = ${deleted}`]; const params = []; let paramIndex = 1; @@ -30,14 +34,12 @@ export async function GET(req: NextRequest) { const whereClause = `WHERE ${conditions.join(" AND ")}`; - // 获取总记录数 const countResult = await query( `SELECT COUNT(*) FROM users ${whereClause}`, params ); const total = parseInt(countResult.rows[0].count); - // 获取分页数据 const result = await query( `SELECT id, email, name, role, balance, deleted, created_at FROM users diff --git a/app/apple-icon.png b/app/apple-icon.png index ed203ef5e88beab19b3dd07dfad56d8b9917fe66..0a4d8c0a24b0d16ccaee8b4f65155715d8511a14 100644 Binary files a/app/apple-icon.png and b/app/apple-icon.png differ diff --git a/app/globals.css b/app/globals.css index 3e8008e4350414d4a7d61cd5d7c3633b62e4be47..e9f48c7af8231e909be1bee4d0812609af4cdefa 100644 --- a/app/globals.css +++ b/app/globals.css @@ -80,7 +80,6 @@ body { display: none; } -/* 更新模态框样式 */ .update-modal .ant-modal-content { padding: 24px; border-radius: 16px; @@ -222,7 +221,6 @@ body { background-size: 24px 24px; } -/* 添加以下样式 */ @media (max-width: 640px) { .toaster-group { --viewport-padding: 16px; @@ -240,7 +238,6 @@ body { } } -/* 自定义日期选择器样式 */ .custom-date-picker { border-radius: 0.5rem; box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1); diff --git a/app/icon.png b/app/icon.png index ed203ef5e88beab19b3dd07dfad56d8b9917fe66..0a4d8c0a24b0d16ccaee8b4f65155715d8511a14 100644 Binary files a/app/icon.png and b/app/icon.png differ diff --git a/app/models/page.tsx b/app/models/page.tsx index 86dfc108b7fe349740153b13c5d6ae5e62e56583..3b145b5d4a3b7df13871afad585c2be3b7c1d600 100644 --- a/app/models/page.tsx +++ b/app/models/page.tsx @@ -263,7 +263,12 @@ export default function ModelsPage() { useEffect(() => { const fetchModels = async () => { try { - const response = await fetch("/api/v1/models"); + const token = localStorage.getItem("access_token"); + const response = await fetch("/api/v1/models", { + headers: { + Authorization: `Bearer ${token}`, + }, + }); if (!response.ok) { throw new Error(t("error.model.failToFetchModels")); } @@ -291,7 +296,12 @@ export default function ModelsPage() { useEffect(() => { const fetchApiKey = async () => { try { - const response = await fetch("/api/config/key"); + const token = localStorage.getItem("access_token"); + const response = await fetch("/api/v1/config/key", { + headers: { + Authorization: `Bearer ${token}`, + }, + }); if (!response.ok) { throw new Error( `${t("error.model.failToFetchApiKey")}: ${response.status}` @@ -342,9 +352,13 @@ export default function ModelsPage() { const per_msg_price = field === "per_msg_price" ? validValue : model.per_msg_price; + const token = localStorage.getItem("access_token"); const response = await fetch("/api/v1/models/price", { method: "POST", - headers: { "Content-Type": "application/json" }, + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, body: JSON.stringify({ updates: [ { @@ -416,10 +430,12 @@ export default function ModelsPage() { try { setSyncing(true); + const token = localStorage.getItem("access_token"); const response = await fetch("/api/v1/models/sync-all-prices", { method: "POST", headers: { "Content-Type": "application/json", + Authorization: `Bearer ${token}`, }, }); @@ -429,7 +445,6 @@ export default function ModelsPage() { throw new Error(data.error || t("models.syncFail")); } - // 更新模型数据 if (data.syncedModels && data.syncedModels.length > 0) { setModels((prev) => prev.map((model) => { @@ -655,11 +670,12 @@ export default function ModelsPage() { const timeoutId = setTimeout(() => controller.abort(), 30000); try { + const token = localStorage.getItem("access_token"); const response = await fetch("/api/v1/models/test", { method: "POST", headers: { "Content-Type": "application/json", - Authorization: `Bearer ${apiKey}`, + Authorization: `Bearer ${token}`, }, body: JSON.stringify({ modelId: model.id, @@ -725,7 +741,6 @@ export default function ModelsPage() { } }; - // 修改表格样式 const tableClassName = ` [&_.ant-table]:!border-b-0 [&_.ant-table-container]:!rounded-xl @@ -746,7 +761,6 @@ export default function ModelsPage() { [&_.ant-table-cell:last-child]:!pr-6 `; - // 修改移动端卡片组件 const MobileCard = ({ record }: { record: Model }) => { const isPerMsgEnabled = record.per_msg_price >= 0; @@ -821,7 +835,6 @@ export default function ModelsPage() { ); }; - // 将 renderPriceCell 修改为接收一个额外的 showTooltip 参数 const renderPriceCell = ( field: "input_price" | "output_price" | "per_msg_price", record: Model, @@ -841,9 +854,7 @@ export default function ModelsPage() { try { await handlePriceUpdate(record.id, field, value); setEditingCell(null); - } catch { - // 错误已经在 handlePriceUpdate 中处理 - } + } catch {} }} t={t} disabled={isDisabled} @@ -875,7 +886,6 @@ export default function ModelsPage() { return (
- {/* 添加 Toaster 组件 */} - {/* 页面标题部分 */}

{t("models.title")} @@ -892,7 +901,6 @@ export default function ModelsPage() {

{t("models.description")}

- {/* 操作按钮组 */}
- {/* 替换原有的TestProgress组件 */} - {/* 桌面端表格视图 */}
{loading ? ( @@ -1034,7 +1040,6 @@ export default function ModelsPage() {
- {/* 移动端卡片视图 */}
{loading ? ( diff --git a/app/page.tsx b/app/page.tsx index 4d9ef6f101a8dfcc8f98c930a988ac3d28741525..631d6ede93fbc9af960bff91ab93de18eb15b0cb 100644 --- a/app/page.tsx +++ b/app/page.tsx @@ -31,7 +31,6 @@ export default function HomePage() { const currentVer = APP_VERSION.replace(/^v/, ""); const newVer = latestVer.replace(/^v/, ""); - // 检查是否有更新且用户未禁用该版本的提示 const ignoredVersion = localStorage.getItem("ignoredVersion"); if (currentVer !== newVer && ignoredVersion !== latestVer) { setLatestVersion(latestVer); @@ -61,7 +60,6 @@ export default function HomePage() { return (
- {/* 新增动态网格背景 */} - {/* 装饰性背景模糊圆 */}
- {/* 修改主要内容容器,使用flex布局固定GitHub在底部 */} - {/* 标题区域保持不变 */} - {/* 重新设计的导航区域 */}
- {/* 装饰性背景 */}
- {/* 新的垂直导航设计 */}
{[ { @@ -164,15 +156,12 @@ export default function HomePage() { shadow-[0_4px_20px_-4px_rgba(0,0,0,0.05)] hover:shadow-[0_8px_30px_-4px_rgba(0,0,0,0.12)]" > - {/* 移除边框设计,只保留渐变背景 */}
- {/* 内容区域 */}
- {/* 图标容器 - 使用投影替代边框 */}
- {/* 文字内容 */}

{item.title} @@ -195,7 +183,6 @@ export default function HomePage() {

- {/* 箭头 */}
- {/* GitHub 图标固定在底部 */} - {/* 更新提示框样式修改 */} {isUpdateVisible && ( ([ - new Date(), // 将在加载数据后更新 + new Date(), new Date(), ]); const [availableTimeRange, setAvailableTimeRange] = useState<{ @@ -109,18 +109,26 @@ export default function PanelPage() { const fetchUsageData = async (range: [Date, Date]) => { setLoading(true); try { - const startTime = dayjs(range[0]).startOf("day").toISOString(); - const endTime = dayjs(range[1]).endOf("day").toISOString(); + const startTime = dayjs(range[0]) + .startOf("day") + .format("YYYY-MM-DDTHH:mm:ssZ"); + const endTime = dayjs(range[1]) + .endOf("day") + .format("YYYY-MM-DDTHH:mm:ssZ"); const url = `/api/v1/panel/usage?startTime=${startTime}&endTime=${endTime}`; - const response = await fetch(url); + const token = localStorage.getItem("access_token"); + const response = await fetch(url, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); if (!response.ok) throw new Error("Failed to fetch data"); const data = await response.json(); setUsageData(data); - // 如果是全部时间范围,更新可用时间范围 if (timeRangeType === "all") { const minTime = new Date(data.timeRange.minTime); const maxTime = new Date(data.timeRange.maxTime); @@ -143,7 +151,6 @@ export default function PanelPage() { params.pagination.pageSize?.toString() || "10" ); - // 添加排序和过滤参数 if (params.sortField) { searchParams.append("sortField", params.sortField); searchParams.append("sortOrder", params.sortOrder || "ascend"); @@ -155,18 +162,23 @@ export default function PanelPage() { searchParams.append("models", params.filters.model_name.join(",")); } - // 添加日期范围 searchParams.append( "startDate", - dayjs(range[0]).startOf("day").format("YYYY-MM-DD") + dayjs(range[0]).startOf("day").format("YYYY-MM-DDTHH:mm:ssZ") ); searchParams.append( "endDate", - dayjs(range[1]).endOf("day").format("YYYY-MM-DD") + dayjs(range[1]).endOf("day").format("YYYY-MM-DDTHH:mm:ssZ") ); + const token = localStorage.getItem("access_token"); const response = await fetch( - `/api/v1/panel/records?${searchParams.toString()}` + `/api/v1/panel/records?${searchParams.toString()}`, + { + headers: { + Authorization: `Bearer ${token}`, + }, + } ); const data = await response.json(); @@ -187,15 +199,18 @@ export default function PanelPage() { useEffect(() => { const loadInitialData = async () => { - // 获取全部时间范围的数据 - const response = await fetch("/api/v1/panel/usage"); + const token = localStorage.getItem("access_token"); + const response = await fetch("/api/v1/panel/usage", { + headers: { + Authorization: `Bearer ${token}`, + }, + }); const data = await response.json(); const minTime = dayjs(data.timeRange.minTime).startOf("day").toDate(); const maxTime = dayjs(data.timeRange.maxTime).endOf("day").toDate(); setAvailableTimeRange({ minTime, maxTime }); - // 设置为全部时间范围 const allTimeRange: [Date, Date] = [minTime, maxTime]; setDateRange(allTimeRange); setTimeRangeType("all"); @@ -237,11 +252,9 @@ export default function PanelPage() { }; const renderDateRangeLabel = () => { - // 如果是同一天,只显示一个日期 if (dayjs(dateRange[0]).isSame(dateRange[1], "day")) { return dayjs(dateRange[0]).format("YYYY-MM-DD"); } - // 否则显示日期范围 return `${dayjs(dateRange[0]).format("YYYY-MM-DD")} ~ ${dayjs( dateRange[1] ).format("YYYY-MM-DD")}`; diff --git a/app/records/page.tsx b/app/records/page.tsx index 0daf5f1eb82979231ed6b08d41505e125380eeb2..ce7953261138be53fd2218e0491e11f861eaff73 100644 --- a/app/records/page.tsx +++ b/app/records/page.tsx @@ -177,7 +177,6 @@ export default function RecordsPage() { }, }); - // 设置筛选选项 setUsers(data.users as string[]); setModels(data.models as string[]); } catch (error) { @@ -230,7 +229,6 @@ export default function RecordsPage() { } }; - // 修改表格样式 const tableClassName = ` [&_.ant-table]:!border-b-0 [&_.ant-table-container]:!rounded-xl diff --git a/app/token/page.tsx b/app/token/page.tsx index 8f5db25c03afac49ce488a23b3a822f5a02faffc..7092c90a2b9185215b6ce00f49b6fbb36f951533 100644 --- a/app/token/page.tsx +++ b/app/token/page.tsx @@ -34,7 +34,7 @@ export default function TokenPage() { setLoading(true); try { localStorage.setItem("access_token", token); - const res = await fetch("/api/config", { + const res = await fetch("/api/v1/config", { headers: { Authorization: `Bearer ${token}`, }, @@ -76,7 +76,6 @@ export default function TokenPage() { )} /> - {/* 装饰性背景模糊圆 */}
@@ -112,7 +111,6 @@ export default function TokenPage() { className="backdrop-blur-[20px] bg-white/[0.08] p-10 rounded-[2.5rem] border border-white/20 shadow-2xl relative overflow-hidden hover:shadow-[0_8px_60px_rgba(120,119,198,0.15)] transition-shadow duration-300 group" > - {/* 新增流光边框效果 */}
- {/* 新增动态粒子背景 */}
(
@@ -310,7 +309,7 @@ export default function UsersPage() { const fetchUsers = async (page: number, isBlacklist: boolean = false) => { setLoading(true); try { - let url = `/api/users?page=${page}&deleted=${isBlacklist}`; + let url = `/api/v1/users?page=${page}&deleted=${isBlacklist}`; if (sortInfo.field && sortInfo.order) { url += `&sortField=${sortInfo.field}&sortOrder=${sortInfo.order}`; } @@ -318,7 +317,12 @@ export default function UsersPage() { url += `&search=${encodeURIComponent(searchText)}`; } - const res = await fetch(url); + const token = localStorage.getItem("access_token"); + const res = await fetch(url, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); const data = await res.json(); if (!res.ok) throw new Error(data.error); @@ -339,7 +343,12 @@ export default function UsersPage() { const fetchBlacklistTotal = async () => { try { - const res = await fetch(`/api/users?page=1&deleted=true&pageSize=1`); + const token = localStorage.getItem("access_token"); + const res = await fetch(`/api/v1/users?page=1&deleted=true&pageSize=1`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); const data = await res.json(); if (!res.ok) throw new Error(data.error); setBlacklistTotal(data.total); @@ -361,29 +370,41 @@ export default function UsersPage() { const handleUpdateBalance = async (userId: string, newBalance: number) => { try { - const res = await fetch(`/api/users/${userId}/balance`, { + console.log(`Updating balance for user ${userId} to ${newBalance}`); + + const token = localStorage.getItem("access_token"); + if (!token) { + throw new Error(t("auth.unauthorized")); + } + + const res = await fetch(`/api/v1/users/${userId}/balance`, { method: "PUT", - headers: { "Content-Type": "application/json" }, + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, body: JSON.stringify({ balance: newBalance }), }); const data = await res.json(); - if (!res.ok) throw new Error(data.error); + console.log("Update balance response:", data); + + if (!res.ok) { + throw new Error(data.error || t("users.message.updateBalance.error")); + } - // 立即更新本地数据 setUsers( users.map((user) => user.id === userId ? { ...user, balance: newBalance } : user ) ); - // 然后再显示成功提示并清除编辑状态 toast.success(t("users.message.updateBalance.success")); setEditingKey(""); - // 最后再重新获取完整列表 fetchUsers(currentPage, false); } catch (err) { + console.error("Failed to update balance:", err); toast.error( err instanceof Error ? err.message @@ -396,10 +417,12 @@ export default function UsersPage() { if (!userToDelete) return; try { - const res = await fetch(`/api/users/${userToDelete.id}`, { + const token = localStorage.getItem("access_token"); + const res = await fetch(`/api/v1/users/${userToDelete.id}`, { method: "PATCH", headers: { "Content-Type": "application/json", + Authorization: `Bearer ${token}`, }, body: JSON.stringify({ deleted: !userToDelete.deleted, @@ -782,7 +805,6 @@ export default function UsersPage() { ); }; - // 添加空状态组件 const EmptyState = ({ searchText }: { searchText: string }) => (
diff --git a/components/AuthCheck.tsx b/components/AuthCheck.tsx index ac8bcd6efec9c14d9ea2d3cc2e3ec2339d517df5..4baac936fa53439fc3aaeec658a9da130a5b4d31 100644 --- a/components/AuthCheck.tsx +++ b/components/AuthCheck.tsx @@ -9,9 +9,20 @@ export default function AuthCheck({ children }: { children: React.ReactNode }) { const router = useRouter(); const pathname = usePathname(); + useEffect(() => { + const initDb = async () => { + try { + await fetch("/api/init"); + } catch (error) { + console.error("初始化数据库失败:", error); + } + }; + + initDb(); + }, []); + useEffect(() => { const checkAuth = async () => { - // 如果已经在token页面,不需要检查 if (pathname === "/token") { setIsLoading(false); setIsAuthorized(true); @@ -25,7 +36,7 @@ export default function AuthCheck({ children }: { children: React.ReactNode }) { } try { - const res = await fetch("/api/config", { + const res = await fetch("/api/v1/config", { headers: { Authorization: `Bearer ${token}`, }, @@ -49,7 +60,6 @@ export default function AuthCheck({ children }: { children: React.ReactNode }) { checkAuth(); }, [router, pathname]); - // 显示加载状态或空白页面 if (isLoading || !isAuthorized) { return null; } diff --git a/components/Header.tsx b/components/Header.tsx index 62e26f648ad0571062ca0ad502deff80c48a9b62..09169908dd0fb3491dfb392268d259ec17840e4b 100644 --- a/components/Header.tsx +++ b/components/Header.tsx @@ -40,15 +40,12 @@ export default function Header() { const [isBackupModalOpen, setIsBackupModalOpen] = useState(false); const [isCheckingUpdate, setIsCheckingUpdate] = useState(false); const [accessToken, setAccessToken] = useState(null); - const [isSettingsExpanded, setIsSettingsExpanded] = useState(false); - // 将函数声明移到前面 const handleLanguageChange = async (newLang: string) => { await i18n.changeLanguage(newLang); localStorage.setItem("language", newLang); }; - // 如果是token页面,只显示语言切换按钮 const isTokenPage = pathname === "/token"; if (isTokenPage) { @@ -89,15 +86,13 @@ export default function Header() { return; } - // 验证token的有效性 - fetch("/api/config", { + fetch("/api/v1/config", { headers: { Authorization: `Bearer ${token}`, }, }) .then((res) => { if (!res.ok) { - // 如果token无效,清除token并重定向 localStorage.removeItem("access_token"); router.push("/token"); return; @@ -111,7 +106,6 @@ export default function Header() { }) .catch(() => { setApiKey(t("common.error")); - // 发生错误时也清除token并重定向 localStorage.removeItem("access_token"); router.push("/token"); }); @@ -292,11 +286,10 @@ export default function Header() { ]; const menuItems = [ - // 在小屏幕上将导航项添加到菜单中,但需要特殊处理 ...(!isTokenPage ? navigationItems.map((item) => ({ ...item, - onClick: () => router.push(item.path), // 为导航项添加 onClick 处理 + onClick: () => router.push(item.path), })) : []), { @@ -325,7 +318,6 @@ export default function Header() { }, ]; - // 在navigationItems数组后添加 const actionItems = [ { icon: , @@ -372,9 +364,7 @@ export default function Header() { - {/* 右侧内容 */}
- {/* 导航项 - 仅在大屏幕显示 */} {!isTokenPage && (
{navigationItems.map((item) => ( @@ -401,7 +391,6 @@ export default function Header() {
)} - {/* 语言切换和菜单按钮 */}
{actionItems.map((item, index) => (
- {/* 菜单项列表 */}
- {/* 导航项 - 仅在移动端显示 */}
{navigationItems.map((item, index) => ( - {/* 设置选项 - 直接显示所有选项 */}
{settingsItems.map((item, index) => (
- {/* 桌面设备表格 */}
- {/* 移动设备卡片列表 */}
{loading ? (
diff --git a/components/panel/UserRankingChart.tsx b/components/panel/UserRankingChart.tsx index 36c7cf76db149d5c014da661dc7fe9d41f4a3781..72ab07d1c80efbe94d41d8a5eda8ac1dc8f401c2 100644 --- a/components/panel/UserRankingChart.tsx +++ b/components/panel/UserRankingChart.tsx @@ -225,13 +225,12 @@ export default function UserRankingChart({ const onChartReady = (instance: ECharts) => { chartRef.current = instance; const zoomSize = 6; - let isZoomed = false; // 增加一个状态变量 + let isZoomed = false; instance.on("click", (params) => { const dataLength = users.length; if (!isZoomed) { - // 第一次点击,放大区域 instance.dispatchAction({ type: "dataZoom", startValue: @@ -242,7 +241,6 @@ export default function UserRankingChart({ }); isZoomed = true; } else { - // 第二次点击,还原缩放 instance.dispatchAction({ type: "dataZoom", start: 0, diff --git a/components/ui/animated-grid-pattern.tsx b/components/ui/animated-grid-pattern.tsx index 10b20a90e3c3185871ac7866ae7a05228576035d..b0e8d20e542395102462f50d89edb3eafaf521bf 100644 --- a/components/ui/animated-grid-pattern.tsx +++ b/components/ui/animated-grid-pattern.tsx @@ -43,7 +43,6 @@ export function AnimatedGridPattern({ ]; } - // Adjust the generateSquares function to return objects with an id, x, and y function generateSquares(count: number) { return Array.from({ length: count }, (_, i) => ({ id: i, @@ -51,7 +50,6 @@ export function AnimatedGridPattern({ })); } - // Function to update a single square's position const updateSquarePosition = (id: number) => { setSquares((currentSquares) => currentSquares.map((sq) => @@ -65,14 +63,12 @@ export function AnimatedGridPattern({ ); }; - // Update squares to animate in useEffect(() => { if (dimensions.width && dimensions.height) { setSquares(generateSquares(numSquares)); } }, [dimensions, numSquares]); - // Resize observer to update container dimensions useEffect(() => { const resizeObserver = new ResizeObserver((entries) => { for (let entry of entries) { diff --git a/components/ui/chart.tsx b/components/ui/chart.tsx index 0510f5bea54235e039e9c3cd648349567782e8f5..a9cbf2862c1db401693a6289621425dce339c7ea 100644 --- a/components/ui/chart.tsx +++ b/components/ui/chart.tsx @@ -1,55 +1,49 @@ -"use client" +"use client"; -import * as React from "react" -import * as RechartsPrimitive from "recharts" -import { - NameType, - Payload, - ValueType, -} from "recharts/types/component/DefaultTooltipContent" +import * as React from "react"; +import * as RechartsPrimitive from "recharts"; -import { cn } from "@/lib/utils" +import { cn } from "@/lib/utils"; -// Format: { THEME_NAME: CSS_SELECTOR } -const THEMES = { light: "", dark: ".dark" } as const +const THEMES = { light: "", dark: ".dark" } as const; export type ChartConfig = { [k in string]: { - label?: React.ReactNode - icon?: React.ComponentType + label?: React.ReactNode; + icon?: React.ComponentType; } & ( | { color?: string; theme?: never } | { color?: never; theme: Record } - ) -} + ); +}; type ChartContextProps = { - config: ChartConfig -} + config: ChartConfig; +}; -const ChartContext = React.createContext(null) +const ChartContext = React.createContext(null); function useChart() { - const context = React.useContext(ChartContext) + const context = React.useContext(ChartContext); if (!context) { - throw new Error("useChart must be used within a ") + throw new Error("useChart must be used within a "); } - return context + return context; } const ChartContainer = React.forwardRef< HTMLDivElement, React.ComponentProps<"div"> & { - config: ChartConfig + config: ChartConfig; children: React.ComponentProps< typeof RechartsPrimitive.ResponsiveContainer - >["children"] + >["children"]; } >(({ id, className, children, config, ...props }, ref) => { - const uniqueId = React.useId() - const chartId = `chart-${id || uniqueId.replace(/:/g, "")}` + const uniqueId = React.useId(); + const chartId = `chart-${id || uniqueId.replace(/:/g, "")}`; return ( @@ -68,17 +62,17 @@ const ChartContainer = React.forwardRef<
- ) -}) -ChartContainer.displayName = "Chart" + ); +}); +ChartContainer.displayName = "Chart"; const ChartStyle = ({ id, config }: { id: string; config: ChartConfig }) => { const colorConfig = Object.entries(config).filter( ([_, config]) => config.theme || config.color - ) + ); if (!colorConfig.length) { - return null + return null; } return ( @@ -92,8 +86,8 @@ ${colorConfig .map(([key, itemConfig]) => { const color = itemConfig.theme?.[theme as keyof typeof itemConfig.theme] || - itemConfig.color - return color ? ` --color-${key}: ${color};` : null + itemConfig.color; + return color ? ` --color-${key}: ${color};` : null; }) .join("\n")} } @@ -102,20 +96,20 @@ ${colorConfig .join("\n"), }} /> - ) -} + ); +}; -const ChartTooltip = RechartsPrimitive.Tooltip +const ChartTooltip = RechartsPrimitive.Tooltip; const ChartTooltipContent = React.forwardRef< HTMLDivElement, React.ComponentProps & React.ComponentProps<"div"> & { - hideLabel?: boolean - hideIndicator?: boolean - indicator?: "line" | "dot" | "dashed" - nameKey?: string - labelKey?: string + hideLabel?: boolean; + hideIndicator?: boolean; + indicator?: "line" | "dot" | "dashed"; + nameKey?: string; + labelKey?: string; } >( ( @@ -136,34 +130,34 @@ const ChartTooltipContent = React.forwardRef< }, ref ) => { - const { config } = useChart() + const { config } = useChart(); const tooltipLabel = React.useMemo(() => { if (hideLabel || !payload?.length) { - return null + return null; } - const [item] = payload - const key = `${labelKey || item.dataKey || item.name || "value"}` - const itemConfig = getPayloadConfigFromPayload(config, item, key) + const [item] = payload; + const key = `${labelKey || item.dataKey || item.name || "value"}`; + const itemConfig = getPayloadConfigFromPayload(config, item, key); const value = !labelKey && typeof label === "string" ? config[label as keyof typeof config]?.label || label - : itemConfig?.label + : itemConfig?.label; if (labelFormatter) { return (
{labelFormatter(value, payload)}
- ) + ); } if (!value) { - return null + return null; } - return
{value}
+ return
{value}
; }, [ label, labelFormatter, @@ -172,13 +166,13 @@ const ChartTooltipContent = React.forwardRef< labelClassName, config, labelKey, - ]) + ]); if (!active || !payload?.length) { - return null + return null; } - const nestLabel = payload.length === 1 && indicator !== "dot" + const nestLabel = payload.length === 1 && indicator !== "dot"; return (
{payload.map((item, index) => { - const key = `${nameKey || item.name || item.dataKey || "value"}` - const itemConfig = getPayloadConfigFromPayload(config, item, key) - const indicatorColor = color || item.payload.fill || item.color + const key = `${nameKey || item.name || item.dataKey || "value"}`; + const itemConfig = getPayloadConfigFromPayload(config, item, key); + const indicatorColor = color || item.payload.fill || item.color; return (
)}
- ) + ); })}
- ) + ); } -) -ChartTooltipContent.displayName = "ChartTooltip" +); +ChartTooltipContent.displayName = "ChartTooltip"; -const ChartLegend = RechartsPrimitive.Legend +const ChartLegend = RechartsPrimitive.Legend; const ChartLegendContent = React.forwardRef< HTMLDivElement, React.ComponentProps<"div"> & Pick & { - hideIcon?: boolean - nameKey?: string + hideIcon?: boolean; + nameKey?: string; } >( ( { className, hideIcon = false, payload, verticalAlign = "bottom", nameKey }, ref ) => { - const { config } = useChart() + const { config } = useChart(); if (!payload?.length) { - return null + return null; } return ( @@ -291,8 +285,8 @@ const ChartLegendContent = React.forwardRef< )} > {payload.map((item) => { - const key = `${nameKey || item.dataKey || "value"}` - const itemConfig = getPayloadConfigFromPayload(config, item, key) + const key = `${nameKey || item.dataKey || "value"}`; + const itemConfig = getPayloadConfigFromPayload(config, item, key); return (
- ) + ); })}
- ) + ); } -) -ChartLegendContent.displayName = "ChartLegend" +); +ChartLegendContent.displayName = "ChartLegend"; -// Helper to extract item config from a payload. function getPayloadConfigFromPayload( config: ChartConfig, payload: unknown, key: string ) { if (typeof payload !== "object" || payload === null) { - return undefined + return undefined; } const payloadPayload = @@ -336,15 +329,15 @@ function getPayloadConfigFromPayload( typeof payload.payload === "object" && payload.payload !== null ? payload.payload - : undefined + : undefined; - let configLabelKey: string = key + let configLabelKey: string = key; if ( key in payload && typeof payload[key as keyof typeof payload] === "string" ) { - configLabelKey = payload[key as keyof typeof payload] as string + configLabelKey = payload[key as keyof typeof payload] as string; } else if ( payloadPayload && key in payloadPayload && @@ -352,12 +345,12 @@ function getPayloadConfigFromPayload( ) { configLabelKey = payloadPayload[ key as keyof typeof payloadPayload - ] as string + ] as string; } return configLabelKey in config ? config[configLabelKey] - : config[key as keyof typeof config] + : config[key as keyof typeof config]; } export { @@ -367,4 +360,4 @@ export { ChartLegend, ChartLegendContent, ChartStyle, -} +}; diff --git a/components/ui/sidebar.tsx b/components/ui/sidebar.tsx index 921f05c1face5ffbfca9b87dfc66e01c33097a52..aaec6bf0f7f6e358fa49756897a94d604cd5433f 100644 --- a/components/ui/sidebar.tsx +++ b/components/ui/sidebar.tsx @@ -70,8 +70,6 @@ const SidebarProvider = React.forwardRef< const isMobile = useIsMobile(); const [openMobile, setOpenMobile] = React.useState(false); - // This is the internal state of the sidebar. - // We use openProp and setOpenProp for control from outside the component. const [_open, _setOpen] = React.useState(defaultOpen); const open = openProp ?? _open; const setOpen = React.useCallback( @@ -83,20 +81,17 @@ const SidebarProvider = React.forwardRef< _setOpen(openState); } - // This sets the cookie to keep the sidebar state. document.cookie = `${SIDEBAR_COOKIE_NAME}=${openState}; path=/; max-age=${SIDEBAR_COOKIE_MAX_AGE}`; }, [setOpenProp, open] ); - // Helper to toggle the sidebar. const toggleSidebar = React.useCallback(() => { return isMobile ? setOpenMobile((open) => !open) : setOpen((open) => !open); }, [isMobile, setOpen, setOpenMobile]); - // Adds a keyboard shortcut to toggle the sidebar. React.useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { if ( @@ -112,8 +107,6 @@ const SidebarProvider = React.forwardRef< return () => window.removeEventListener("keydown", handleKeyDown); }, [toggleSidebar]); - // We add a state so that we can do data-state="expanded" or "collapsed". - // This makes it easier to style the sidebar with Tailwind classes. const state = open ? "expanded" : "collapsed"; const contextValue = React.useMemo( @@ -221,7 +214,6 @@ const Sidebar = React.forwardRef< data-variant={variant} data-side={side} > - {/* This is what handles the sidebar gap on desktop */}
(({ className, showIcon = false, ...props }, ref) => { - // Random width between 50 to 90%. const width = React.useMemo(() => { return `${Math.floor(Math.random() * 40) + 50}%`; }, []); diff --git a/hooks/use-toast.ts b/hooks/use-toast.ts index 02e111d81dd774038ac483c11b5f5a8f8aceb024..44bfdbdd479caf03efbbfb5d1d7499b40024f0b2 100644 --- a/hooks/use-toast.ts +++ b/hooks/use-toast.ts @@ -1,78 +1,74 @@ -"use client" +"use client"; -// Inspired by react-hot-toast library -import * as React from "react" +import * as React from "react"; -import type { - ToastActionElement, - ToastProps, -} from "@/components/ui/toast" +import type { ToastActionElement, ToastProps } from "@/components/ui/toast"; -const TOAST_LIMIT = 1 -const TOAST_REMOVE_DELAY = 1000000 +const TOAST_LIMIT = 1; +const TOAST_REMOVE_DELAY = 1000000; type ToasterToast = ToastProps & { - id: string - title?: React.ReactNode - description?: React.ReactNode - action?: ToastActionElement -} + id: string; + title?: React.ReactNode; + description?: React.ReactNode; + action?: ToastActionElement; +}; const actionTypes = { ADD_TOAST: "ADD_TOAST", UPDATE_TOAST: "UPDATE_TOAST", DISMISS_TOAST: "DISMISS_TOAST", REMOVE_TOAST: "REMOVE_TOAST", -} as const +} as const; -let count = 0 +let count = 0; function genId() { - count = (count + 1) % Number.MAX_SAFE_INTEGER - return count.toString() + count = (count + 1) % Number.MAX_SAFE_INTEGER; + return count.toString(); } -type ActionType = typeof actionTypes +type ActionType = typeof actionTypes; type Action = | { - type: ActionType["ADD_TOAST"] - toast: ToasterToast + type: ActionType["ADD_TOAST"]; + toast: ToasterToast; } | { - type: ActionType["UPDATE_TOAST"] - toast: Partial + type: ActionType["UPDATE_TOAST"]; + toast: Partial; } | { - type: ActionType["DISMISS_TOAST"] - toastId?: ToasterToast["id"] + type: ActionType["DISMISS_TOAST"]; + toastId?: ToasterToast["id"]; } | { - type: ActionType["REMOVE_TOAST"] - toastId?: ToasterToast["id"] - } + type: ActionType["REMOVE_TOAST"]; + toastId?: ToasterToast["id"]; + }; interface State { - toasts: ToasterToast[] + toasts: ToasterToast[]; } -const toastTimeouts = new Map>() +const toastTimeouts = new Map>(); const addToRemoveQueue = (toastId: string) => { if (toastTimeouts.has(toastId)) { - return + return; } const timeout = setTimeout(() => { - toastTimeouts.delete(toastId) + toastTimeouts.delete(toastId); dispatch({ type: "REMOVE_TOAST", toastId: toastId, - }) - }, TOAST_REMOVE_DELAY) + }); + }, TOAST_REMOVE_DELAY); - toastTimeouts.set(toastId, timeout) -} + toastTimeouts.set(toastId, timeout); +}; export const reducer = (state: State, action: Action): State => { switch (action.type) { @@ -80,7 +76,7 @@ export const reducer = (state: State, action: Action): State => { return { ...state, toasts: [action.toast, ...state.toasts].slice(0, TOAST_LIMIT), - } + }; case "UPDATE_TOAST": return { @@ -88,19 +84,17 @@ export const reducer = (state: State, action: Action): State => { toasts: state.toasts.map((t) => t.id === action.toast.id ? { ...t, ...action.toast } : t ), - } + }; case "DISMISS_TOAST": { - const { toastId } = action + const { toastId } = action; - // ! Side effects ! - This could be extracted into a dismissToast() action, - // but I'll keep it here for simplicity if (toastId) { - addToRemoveQueue(toastId) + addToRemoveQueue(toastId); } else { state.toasts.forEach((toast) => { - addToRemoveQueue(toast.id) - }) + addToRemoveQueue(toast.id); + }); } return { @@ -113,44 +107,44 @@ export const reducer = (state: State, action: Action): State => { } : t ), - } + }; } case "REMOVE_TOAST": if (action.toastId === undefined) { return { ...state, toasts: [], - } + }; } return { ...state, toasts: state.toasts.filter((t) => t.id !== action.toastId), - } + }; } -} +}; -const listeners: Array<(state: State) => void> = [] +const listeners: Array<(state: State) => void> = []; -let memoryState: State = { toasts: [] } +let memoryState: State = { toasts: [] }; function dispatch(action: Action) { - memoryState = reducer(memoryState, action) + memoryState = reducer(memoryState, action); listeners.forEach((listener) => { - listener(memoryState) - }) + listener(memoryState); + }); } -type Toast = Omit +type Toast = Omit; function toast({ ...props }: Toast) { - const id = genId() + const id = genId(); const update = (props: ToasterToast) => dispatch({ type: "UPDATE_TOAST", toast: { ...props, id }, - }) - const dismiss = () => dispatch({ type: "DISMISS_TOAST", toastId: id }) + }); + const dismiss = () => dispatch({ type: "DISMISS_TOAST", toastId: id }); dispatch({ type: "ADD_TOAST", @@ -159,36 +153,36 @@ function toast({ ...props }: Toast) { id, open: true, onOpenChange: (open) => { - if (!open) dismiss() + if (!open) dismiss(); }, }, - }) + }); return { id: id, dismiss, update, - } + }; } function useToast() { - const [state, setState] = React.useState(memoryState) + const [state, setState] = React.useState(memoryState); React.useEffect(() => { - listeners.push(setState) + listeners.push(setState); return () => { - const index = listeners.indexOf(setState) + const index = listeners.indexOf(setState); if (index > -1) { - listeners.splice(index, 1) + listeners.splice(index, 1); } - } - }, [state]) + }; + }, [state]); return { ...state, toast, dismiss: (toastId?: string) => dispatch({ type: "DISMISS_TOAST", toastId }), - } + }; } -export { useToast, toast } +export { useToast, toast }; diff --git a/lib/auth.ts b/lib/auth.ts new file mode 100644 index 0000000000000000000000000000000000000000..37c5070409e8fcd6c980c7c96518de8e7079d384 --- /dev/null +++ b/lib/auth.ts @@ -0,0 +1,23 @@ +import { NextResponse } from "next/server"; + +const ACCESS_TOKEN = process.env.ACCESS_TOKEN; + +export function verifyApiToken(req: Request) { + if (!ACCESS_TOKEN) { + console.error("ACCESS_TOKEN is not set"); + return NextResponse.json( + { error: "Server configuration error" }, + { status: 500 } + ); + } + + const authHeader = req.headers.get("authorization"); + const token = authHeader?.replace("Bearer ", ""); + + if (!token || token !== ACCESS_TOKEN) { + console.log("Unauthorized access attempt"); + return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + } + + return null; +} diff --git a/lib/dayjs.ts b/lib/dayjs.ts new file mode 100644 index 0000000000000000000000000000000000000000..9a2c56e683ff973294a088a26551f19c91072356 --- /dev/null +++ b/lib/dayjs.ts @@ -0,0 +1,19 @@ +import dayjs from "dayjs"; +import utc from "dayjs/plugin/utc"; +import timezone from "dayjs/plugin/timezone"; + +dayjs.extend(utc); +dayjs.extend(timezone); + +const localTimezone = Intl.DateTimeFormat().resolvedOptions().timeZone; +dayjs.tz.setDefault(localTimezone); + +const originalFormat = dayjs.prototype.format; +dayjs.prototype.format = function (template: string) { + if (template === "YYYY-MM-DDTHH:mm:ssZ") { + return this.toISOString(); + } + return originalFormat.call(this, template); +}; + +export default dayjs; diff --git a/lib/db.ts b/lib/db.ts deleted file mode 100644 index fba36f4fb7f36c01b3bb5276a4079c957d49c9b4..0000000000000000000000000000000000000000 --- a/lib/db.ts +++ /dev/null @@ -1,352 +0,0 @@ -import { Pool, PoolClient } from "pg"; - -// 构建数据库连接配置 -const dbConfig = process.env.POSTGRES_URL - ? { - // 远程数据库配置 - connectionString: process.env.POSTGRES_URL, - ssl: { - rejectUnauthorized: false, // 允许自签名证书 - }, - } - : { - // 本地 Docker 数据库配置 - host: process.env.POSTGRES_HOST || "localhost", - user: process.env.POSTGRES_USER || "postgres", - password: process.env.POSTGRES_PASSWORD, - database: process.env.POSTGRES_DATABASE || "openwebui_monitor", - ssl: false, - }; - -// 创建连接池 -export const pool = new Pool(dbConfig); - -// 测试连接 -pool.on("error", (err) => { - console.error("Unexpected error on idle client", err); - process.exit(-1); -}); - -// 数据库行的类型定义 -interface ModelPriceRow { - id: string; - name: string; - input_price: string | number; - output_price: string | number; - per_msg_price: string | number; - updated_at: Date; -} - -export interface ModelPrice { - id: string; - name: string; - input_price: number; - output_price: number; - per_msg_price: number; - updated_at: Date; -} - -export interface UserUsageRecord { - id: number; - userId: number; - nickname: string; - useTime: Date; - modelName: string; - inputTokens: number; - outputTokens: number; - cost: number; - balanceAfter: number; -} - -// 确保表存在 -export async function ensureTablesExist() { - let client: PoolClient | null = null; - try { - client = await pool.connect(); - - // 首先创建 users 表 - await client.query(` - CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, - email TEXT NOT NULL, - name TEXT NOT NULL, - role TEXT NOT NULL DEFAULT 'user', - balance DECIMAL(16, 6) NOT NULL DEFAULT 0, - created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP - ); - `); - - // 获取默认价格 - const defaultInputPrice = parseFloat( - process.env.DEFAULT_MODEL_INPUT_PRICE || "60" - ); - const defaultOutputPrice = parseFloat( - process.env.DEFAULT_MODEL_OUTPUT_PRICE || "60" - ); - const defaultPerMsgPrice = parseFloat( - process.env.DEFAULT_MODEL_PER_MSG_PRICE || "-1" - ); - - // 然后创建 model_prices 表,使用具体的默认值而不是参数绑定 - await client.query(` - CREATE TABLE IF NOT EXISTS model_prices ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - base_model_id TEXT, - input_price NUMERIC(10, 6) DEFAULT ${defaultInputPrice}, - output_price NUMERIC(10, 6) DEFAULT ${defaultOutputPrice}, - per_msg_price NUMERIC(10, 6) DEFAULT ${defaultPerMsgPrice}, - updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP - ); - `); - - // 检查并添加 per_msg_price 列(如果不存在) - await client.query(` - DO $$ - BEGIN - BEGIN - ALTER TABLE model_prices - ADD COLUMN per_msg_price NUMERIC(10, 6) DEFAULT ${defaultPerMsgPrice}; - EXCEPTION - WHEN duplicate_column THEN NULL; - END; - END $$; - `); - - // 检查并添加 base_model_id 列(如果不存在) - await client.query(` - DO $$ - BEGIN - BEGIN - ALTER TABLE model_prices - ADD COLUMN base_model_id TEXT; - EXCEPTION - WHEN duplicate_column THEN NULL; - END; - END $$; - `); - - // 最后创建 user_usage_records 表 - await client.query(` - CREATE TABLE IF NOT EXISTS user_usage_records ( - id SERIAL PRIMARY KEY, - user_id TEXT NOT NULL, - nickname VARCHAR(255) NOT NULL, - use_time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, - model_name VARCHAR(255) NOT NULL, - input_tokens INTEGER NOT NULL, - output_tokens INTEGER NOT NULL, - cost DECIMAL(10, 4) NOT NULL, - balance_after DECIMAL(10, 4) NOT NULL, - FOREIGN KEY (user_id) REFERENCES users(id) - ); - `); - } catch (error) { - console.error("Database connection/initialization error:", error); - throw error; - } finally { - if (client) { - client.release(); - } - } -} - -// 获取模型价格,如果不存在则创建默认值 -export async function getOrCreateModelPrices( - models: Array<{ id: string; name: string; base_model_id?: string }> -): Promise { - let client: PoolClient | null = null; - try { - client = await pool.connect(); - - // 获取默认价格 - const defaultInputPrice = parseFloat( - process.env.DEFAULT_MODEL_INPUT_PRICE || "60" - ); - const defaultOutputPrice = parseFloat( - process.env.DEFAULT_MODEL_OUTPUT_PRICE || "60" - ); - const defaultPerMsgPrice = parseFloat( - process.env.DEFAULT_MODEL_PER_MSG_PRICE || "-1" - ); - - // 1. 首先获取所有已存在的模型价格 - const modelIds = models.map((m) => m.id); - const baseModelIds = models.map((m) => m.base_model_id).filter((id) => id); - - const existingModelsResult = await client.query( - `SELECT * FROM model_prices WHERE id = ANY($1::text[])`, - [modelIds] - ); - - // 2. 获取所有基础模型的价格 - const baseModelsResult = await client.query( - `SELECT * FROM model_prices WHERE id = ANY($1::text[])`, - [baseModelIds] - ); - - const existingModels = new Map( - existingModelsResult.rows.map((row) => [row.id, row]) - ); - const baseModels = new Map( - baseModelsResult.rows.map((row) => [row.id, row]) - ); - - // 3. 更新所有模型的名称并插入缺失的模型 - const modelsToUpdate = models.filter((m) => existingModels.has(m.id)); - const missingModels = models.filter((m) => !existingModels.has(m.id)); - - // 3.1 更新现有模型的名称 - if (modelsToUpdate.length > 0) { - for (const model of modelsToUpdate) { - await client.query(`UPDATE model_prices SET name = $2 WHERE id = $1`, [ - model.id, - model.name, - ]); - } - } - - // 3.2 插入缺失的模型 - if (missingModels.length > 0) { - const values = missingModels.map((m) => { - const baseModel = m.base_model_id - ? baseModels.get(m.base_model_id) - : null; - return [ - m.id, - m.name, - baseModel?.input_price ?? defaultInputPrice, - baseModel?.output_price ?? defaultOutputPrice, - baseModel?.per_msg_price ?? defaultPerMsgPrice, - ]; - }); - - const placeholders = values - .map( - (_, i) => - `($${i * 5 + 1}, $${i * 5 + 2}, $${i * 5 + 3}, $${i * 5 + 4}, $${ - i * 5 + 5 - })` - ) - .join(","); - - const result = await client.query( - `INSERT INTO model_prices (id, name, input_price, output_price, per_msg_price) - VALUES ${placeholders} - RETURNING *`, - values.flat() - ); - - result.rows.forEach((row) => existingModels.set(row.id, row)); - } - - // 4. 重新获取所有模型的最新数据 - const updatedModelsResult = await client.query( - `SELECT * FROM model_prices WHERE id = ANY($1::text[])`, - [modelIds] - ); - - const updatedModels = new Map( - updatedModelsResult.rows.map((row) => [row.id, row]) - ); - - return models.map((m) => { - const row = updatedModels.get(m.id)!; - return { - id: row.id, - name: row.name, - input_price: Number(row.input_price), - output_price: Number(row.output_price), - per_msg_price: Number(row.per_msg_price), - updated_at: row.updated_at, - }; - }); - } catch (error) { - console.error("Error in getOrCreateModelPrices:", error); - throw error; - } finally { - if (client) { - client.release(); - } - } -} - -// 更新模型价格 -export async function updateModelPrice( - id: string, - input_price: number, - output_price: number, - per_msg_price: number -): Promise { - let client: PoolClient | null = null; - try { - client = await pool.connect(); - - // 使用 CAST 确保数据类型正确 - const result = await client.query( - `UPDATE model_prices - SET - input_price = CAST($2 AS NUMERIC(10,6)), - output_price = CAST($3 AS NUMERIC(10,6)), - per_msg_price = CAST($4 AS NUMERIC(10,6)), - updated_at = CURRENT_TIMESTAMP - WHERE id = $1 - RETURNING *`, - [id, input_price, output_price, per_msg_price] - ); - - if (result.rows[0]) { - return { - id: result.rows[0].id, - name: result.rows[0].name, - input_price: Number(result.rows[0].input_price), - output_price: Number(result.rows[0].output_price), - per_msg_price: Number(result.rows[0].per_msg_price), - updated_at: result.rows[0].updated_at, - }; - } - return null; - } catch (error) { - console.error(`Failed to update ${id} price:`, error); - throw error; - } finally { - if (client) { - client.release(); - } - } -} - -// 添加一个始化函数 -export async function initDatabase() { - try { - await ensureTablesExist(); - // console.log("Database initialized successfully"); - } catch (error) { - console.error("Failed to initialize database:", error); - throw error; - } -} - -// 更新用户余额 -export async function updateUserBalance(userId: string, balance: number) { - let client: PoolClient | null = null; - try { - client = await pool.connect(); - const result = await client.query( - `UPDATE users - SET balance = $2 - WHERE id = $1 - RETURNING id, email, balance`, - [userId, balance] - ); - - return result.rows[0]; - } catch (error) { - console.error("Error in updateUserBalance:", error); - throw error; - } finally { - if (client) { - client.release(); - } - } -} diff --git a/lib/db/client.ts b/lib/db/client.ts index fbd5b54fe00b2b4fe5a2a8972787d57e1c32e5fc..1f8420000f0f5a33af774ac1912a326e18e26f0e 100644 --- a/lib/db/client.ts +++ b/lib/db/client.ts @@ -6,7 +6,6 @@ import { Pool, PoolClient } from "pg"; const isVercel = process.env.VERCEL === "1"; -// 为 Vercel 环境添加连接池 let vercelPool: { client: ReturnType; isConnected: boolean; @@ -21,7 +20,6 @@ async function getVercelClient() { }; } - // 如果还没连接,则建立连接 if (!vercelPool.isConnected) { try { await vercelPool.client.connect(); @@ -48,7 +46,8 @@ function getClient() { port: parseInt(process.env.POSTGRES_PORT || "5432"), max: 20, idleTimeoutMillis: 30000, - connectionTimeoutMillis: 2000, + connectionTimeoutMillis: 30000, + statement_timeout: 30000, }; if (process.env.POSTGRES_URL) { @@ -59,7 +58,8 @@ function getClient() { }, max: 20, idleTimeoutMillis: 30000, - connectionTimeoutMillis: 2000, + connectionTimeoutMillis: 30000, + statement_timeout: 30000, }); } else { pgPool = new Pool(config); @@ -74,13 +74,11 @@ function getClient() { } } -// 定义一个通用的查询结果类型 type CommonQueryResult = { rows: T[]; rowCount: number; }; -// 导出一个通用的查询函数 export async function query( text: string, params?: any[] @@ -100,7 +98,6 @@ export async function query( }; } catch (error) { console.error("[DB Query Error]", error); - // 如果连接出错,重置连接状态 if (vercelPool) { vercelPool.isConnected = false; } @@ -128,7 +125,6 @@ export async function query( } } -// 确保在应用关闭时清理连接 if (typeof window === "undefined") { process.on("SIGTERM", async () => { console.log("SIGTERM received, closing database connections"); @@ -143,3 +139,347 @@ if (typeof window === "undefined") { } export { getClient }; + +export async function ensureTablesExist() { + try { + const usersTableExists = await query(` + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'users' + ); + `); + + if (!usersTableExists.rows[0].exists) { + await query(` + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL, + name TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'user', + balance DECIMAL(16, 6) NOT NULL DEFAULT 0, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + deleted BOOLEAN DEFAULT FALSE + ); + `); + } else { + try { + await query(` + DO $$ + BEGIN + BEGIN + ALTER TABLE users + ADD COLUMN deleted BOOLEAN DEFAULT FALSE; + EXCEPTION + WHEN duplicate_column THEN NULL; + END; + END $$; + `); + } catch (error) { + console.error("Error adding deleted column to users table:", error); + } + } + + const modelPricesTableExists = await query(` + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'model_prices' + ); + `); + + const defaultInputPrice = parseFloat( + process.env.DEFAULT_MODEL_INPUT_PRICE || "60" + ); + const defaultOutputPrice = parseFloat( + process.env.DEFAULT_MODEL_OUTPUT_PRICE || "60" + ); + const defaultPerMsgPrice = parseFloat( + process.env.DEFAULT_MODEL_PER_MSG_PRICE || "-1" + ); + + if (!modelPricesTableExists.rows[0].exists) { + await query(` + CREATE TABLE IF NOT EXISTS model_prices ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + base_model_id TEXT, + input_price NUMERIC(10, 6) DEFAULT ${defaultInputPrice}, + output_price NUMERIC(10, 6) DEFAULT ${defaultOutputPrice}, + per_msg_price NUMERIC(10, 6) DEFAULT ${defaultPerMsgPrice}, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP + ); + `); + } else { + try { + await query(` + DO $$ + BEGIN + BEGIN + ALTER TABLE model_prices + ADD COLUMN per_msg_price NUMERIC(10, 6) DEFAULT ${defaultPerMsgPrice}; + EXCEPTION + WHEN duplicate_column THEN NULL; + END; + END $$; + `); + } catch (error) { + console.error("Error adding per_msg_price column:", error); + } + + try { + await query(` + DO $$ + BEGIN + BEGIN + ALTER TABLE model_prices + ADD COLUMN base_model_id TEXT; + EXCEPTION + WHEN duplicate_column THEN NULL; + END; + END $$; + `); + } catch (error) { + console.error("Error adding base_model_id column:", error); + } + } + + const userUsageRecordsTableExists = await query(` + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'user_usage_records' + ); + `); + + if (!userUsageRecordsTableExists.rows[0].exists) { + await query(` + CREATE TABLE IF NOT EXISTS user_usage_records ( + id SERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + nickname VARCHAR(255) NOT NULL, + use_time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + model_name VARCHAR(255) NOT NULL, + input_tokens INTEGER NOT NULL, + output_tokens INTEGER NOT NULL, + cost DECIMAL(10, 4) NOT NULL, + balance_after DECIMAL(10, 4) NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) + ); + `); + } + + console.log("Database tables initialized successfully"); + } catch (error) { + console.error("Failed to initialize database tables:", error); + throw error; + } +} + +export async function initDatabase() { + try { + await ensureTablesExist(); + console.log("Database initialized successfully"); + } catch (error) { + console.error("Failed to initialize database:", error); + throw error; + } +} + +export interface ModelPrice { + id: string; + name: string; + input_price: number; + output_price: number; + per_msg_price: number; + updated_at: Date; +} + +export interface UserUsageRecord { + id: number; + userId: number; + nickname: string; + useTime: Date; + modelName: string; + inputTokens: number; + outputTokens: number; + cost: number; + balanceAfter: number; +} + +export async function getOrCreateModelPrices( + models: Array<{ id: string; name: string; base_model_id?: string }> +): Promise { + try { + const defaultInputPrice = parseFloat( + process.env.DEFAULT_MODEL_INPUT_PRICE || "60" + ); + const defaultOutputPrice = parseFloat( + process.env.DEFAULT_MODEL_OUTPUT_PRICE || "60" + ); + const defaultPerMsgPrice = parseFloat( + process.env.DEFAULT_MODEL_PER_MSG_PRICE || "-1" + ); + + const modelIds = models.map((m) => m.id); + const baseModelIds = models.map((m) => m.base_model_id).filter((id) => id); + + const existingModelsResult = await query( + `SELECT * FROM model_prices WHERE id = ANY($1::text[])`, + [modelIds] + ); + + const baseModelsResult = await query( + `SELECT * FROM model_prices WHERE id = ANY($1::text[])`, + [baseModelIds] + ); + + const existingModels = new Map( + existingModelsResult.rows.map((row) => [row.id, row]) + ); + const baseModels = new Map( + baseModelsResult.rows.map((row) => [row.id, row]) + ); + + const modelsToUpdate = models.filter((m) => existingModels.has(m.id)); + const missingModels = models.filter((m) => !existingModels.has(m.id)); + + if (modelsToUpdate.length > 0) { + for (const model of modelsToUpdate) { + await query(`UPDATE model_prices SET name = $2 WHERE id = $1`, [ + model.id, + model.name, + ]); + } + } + + if (missingModels.length > 0) { + for (const model of missingModels) { + const baseModel = model.base_model_id + ? baseModels.get(model.base_model_id) + : null; + + await query( + `INSERT INTO model_prices (id, name, input_price, output_price, per_msg_price) + VALUES ($1, $2, $3, $4, $5) + RETURNING *`, + [ + model.id, + model.name, + baseModel?.input_price ?? defaultInputPrice, + baseModel?.output_price ?? defaultOutputPrice, + baseModel?.per_msg_price ?? defaultPerMsgPrice, + ] + ); + } + } + + const updatedModelsResult = await query( + `SELECT * FROM model_prices WHERE id = ANY($1::text[])`, + [modelIds] + ); + + return updatedModelsResult.rows.map((row) => ({ + id: row.id, + name: row.name, + input_price: Number(row.input_price), + output_price: Number(row.output_price), + per_msg_price: Number(row.per_msg_price), + updated_at: row.updated_at, + })); + } catch (error) { + console.error("Error in getOrCreateModelPrices:", error); + throw error; + } +} + +export async function updateModelPrice( + id: string, + input_price: number, + output_price: number, + per_msg_price: number +): Promise { + try { + const result = await query( + `UPDATE model_prices + SET + input_price = CAST($2 AS NUMERIC(10,6)), + output_price = CAST($3 AS NUMERIC(10,6)), + per_msg_price = CAST($4 AS NUMERIC(10,6)), + updated_at = CURRENT_TIMESTAMP + WHERE id = $1 + RETURNING *`, + [id, input_price, output_price, per_msg_price] + ); + + if (result.rows[0]) { + return { + id: result.rows[0].id, + name: result.rows[0].model_name, + input_price: Number(result.rows[0].input_price), + output_price: Number(result.rows[0].output_price), + per_msg_price: Number(result.rows[0].per_msg_price), + updated_at: result.rows[0].updated_at, + }; + } + return null; + } catch (error) { + console.error("Error updating model price:", error); + throw error; + } +} + +export async function updateUserBalance(userId: string, balance: number) { + try { + const result = await query( + `UPDATE users + SET balance = $2 + WHERE id = $1 + RETURNING id, email, balance`, + [userId, balance] + ); + + return result.rows[0]; + } catch (error) { + console.error("Error in updateUserBalance:", error); + throw error; + } +} + +export const pool = { + connect: async () => { + if (isVercel) { + return { + query: async (text: string, params?: any[]) => { + const client = await getVercelClient(); + const result = await client.query({ + text, + values: params || [], + }); + return result; + }, + release: () => {}, + }; + } else { + return (pgPool || (getClient() as Pool)).connect(); + } + }, + query: async (text: string, params?: any[]) => { + if (isVercel) { + const client = await getVercelClient(); + return client.query({ + text, + values: params || [], + }); + } else { + return (pgPool || (getClient() as Pool)).query(text, params); + } + }, + end: async () => { + if (isVercel) { + if (vercelPool?.client) { + await vercelPool.client.end(); + vercelPool.isConnected = false; + } + } else if (pgPool) { + await pgPool.end(); + } + }, +}; diff --git a/lib/db/index.ts b/lib/db/index.ts index cdd9fa50ef8fdc7fb4570ebfd1ad54e5136901e1..f7f4adceaeaf0d76acc9b2a97e3f836b9f72b271 100644 --- a/lib/db/index.ts +++ b/lib/db/index.ts @@ -1,8 +1,7 @@ import { query } from "./client"; import { ensureUserTableExists } from "./users"; -import { ModelPrice } from "../db"; +import { ModelPrice, updateModelPrice } from "./client"; -// 创建模型价格表 async function ensureModelPricesTableExists() { const defaultInputPrice = parseFloat( process.env.DEFAULT_MODEL_INPUT_PRICE || "60" @@ -16,7 +15,7 @@ async function ensureModelPricesTableExists() { await query( `CREATE TABLE IF NOT EXISTS model_prices ( - model_id TEXT PRIMARY KEY, + id TEXT PRIMARY KEY, model_name TEXT NOT NULL, input_price DECIMAL(10, 6) DEFAULT CAST($1 AS DECIMAL(10, 6)), output_price DECIMAL(10, 6) DEFAULT CAST($2 AS DECIMAL(10, 6)), @@ -26,7 +25,6 @@ async function ensureModelPricesTableExists() { [defaultInputPrice, defaultOutputPrice, defaultPerMsgPrice] ); - // 为现有记录添加 per_msg_price 字段(如果不存在) await query( `DO $$ BEGIN @@ -56,16 +54,16 @@ export async function getOrCreateModelPrice( ); const result = await query( - `INSERT INTO model_prices (model_id, model_name, per_msg_price, updated_at) + `INSERT INTO model_prices (id, model_name, per_msg_price, updated_at) VALUES ($1, $2, CAST($3 AS DECIMAL(10, 6)), CURRENT_TIMESTAMP) - ON CONFLICT (model_id) DO UPDATE + ON CONFLICT (id) DO UPDATE SET model_name = $2, updated_at = CURRENT_TIMESTAMP RETURNING *`, [id, name, defaultPerMsgPrice] ); return { - id: result.rows[0].model_id, + id: result.rows[0].id, name: result.rows[0].model_name, input_price: Number(result.rows[0].input_price), output_price: Number(result.rows[0].output_price), @@ -83,37 +81,6 @@ export async function getOrCreateModelPrice( } } -export async function updateModelPrice( - modelId: string, - input_price: number, - output_price: number, - per_msg_price: number -): Promise { - const result = await query( - `UPDATE model_prices - SET - input_price = CAST($2 AS DECIMAL(10,6)), - output_price = CAST($3 AS DECIMAL(10,6)), - per_msg_price = CAST($4 AS DECIMAL(10,6)), - updated_at = CURRENT_TIMESTAMP - WHERE model_id = $1 - RETURNING *;`, - [modelId, input_price, output_price, per_msg_price] - ); - - if (result.rows.length === 0) { - return null; - } - - return { - id: result.rows[0].model_id, - name: result.rows[0].model_name, - input_price: Number(result.rows[0].input_price), - output_price: Number(result.rows[0].output_price), - per_msg_price: Number(result.rows[0].per_msg_price), - updated_at: result.rows[0].updated_at, - }; -} export { getUsers, getOrCreateUser, diff --git a/lib/utils/inlet-cost.ts b/lib/utils/inlet-cost.ts index 0cc6839d63523fecfe44d003794b152f31fc060c..15a163cd7b8334811f968486fc1139be1acadce0 100644 --- a/lib/utils/inlet-cost.ts +++ b/lib/utils/inlet-cost.ts @@ -5,13 +5,11 @@ interface ModelInletCost { function parseInletCostConfig(config: string | undefined): ModelInletCost { if (!config) return {}; - // 如果配置是一个数字,对所有模型使用相同的预扣费 const numericValue = Number(config); if (!isNaN(numericValue)) { return { default: numericValue }; } - // 否则解析 model1:0.32,model2:0.01 格式 try { const costs: ModelInletCost = {}; config.split(",").forEach((pair) => { diff --git a/lib/version.ts b/lib/version.ts index 445b0d5346771e3235bbeb2b5b29372dac331ac6..6c6d1af8899a4885fb3f74430cdc9e63d6b2b199 100644 --- a/lib/version.ts +++ b/lib/version.ts @@ -1,2 +1 @@ -// 从环境变量获取版本号 export const APP_VERSION = process.env.NEXT_PUBLIC_APP_VERSION || "0.0.1"; diff --git a/middleware.ts b/middleware.ts index d42c57a0ca4774f21bcf3af0630e0979e9dafe6c..a2713cb48b50343640d070040eeaaff6af54bfba 100644 --- a/middleware.ts +++ b/middleware.ts @@ -7,15 +7,24 @@ const ACCESS_TOKEN = process.env.ACCESS_TOKEN; export async function middleware(request: NextRequest) { const { pathname } = request.nextUrl; - // 只验证 inlet/outlet/test API 请求 if ( pathname.startsWith("/api/v1/inlet") || pathname.startsWith("/api/v1/outlet") || - pathname.startsWith("/api/v1/models/test") + pathname.startsWith("/api/v1/models") || + pathname.startsWith("/api/v1/panel") || + pathname.startsWith("/api/v1/config") || + pathname.startsWith("/api/v1/users") ) { - // API 请求验证 - if (!API_KEY) { - console.error("API Key is not set"); + const token = + pathname.startsWith("/api/v1/panel") || + pathname.startsWith("/api/v1/config") || + pathname.startsWith("/api/v1/users") || + pathname.startsWith("/api/v1/models") + ? ACCESS_TOKEN + : API_KEY; + + if (!token) { + console.error("API Key or Access Token is not set"); return NextResponse.json( { error: "Server configuration error" }, { status: 500 } @@ -25,14 +34,13 @@ export async function middleware(request: NextRequest) { const authHeader = request.headers.get("authorization"); const providedKey = authHeader?.replace("Bearer ", ""); - if (!providedKey || providedKey !== API_KEY) { - console.log("Invalid API key"); - return NextResponse.json({ error: "Invalid API key" }, { status: 401 }); + if (!providedKey || providedKey !== token) { + console.log("Invalid API key or token"); + return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); } return NextResponse.next(); } else if (!pathname.startsWith("/api/")) { - // 页面访问验证 if (!ACCESS_TOKEN) { console.error("ACCESS_TOKEN is not set"); return NextResponse.json( @@ -41,12 +49,10 @@ export async function middleware(request: NextRequest) { ); } - // 如果是令牌验证页面,直接允许访问 if (pathname === "/token") { return NextResponse.next(); } - // 添加 no-store 和 no-cache 头,防止 Cloudflare 缓存 const response = NextResponse.next(); response.headers.set( "Cache-Control", @@ -57,14 +63,14 @@ export async function middleware(request: NextRequest) { return response; } else if (pathname.startsWith("/api/config/key")) { - // 确保这个路径不被中间件拦截 + return NextResponse.next(); + } else if (pathname.startsWith("/api/init")) { return NextResponse.next(); } return NextResponse.next(); } -// 配置中间件匹配的路由 export const config = { matcher: ["/((?!_next/static|_next/image|favicon.ico).*)"], }; diff --git a/package.json b/package.json index 52cc33862e293663516f2ef4e67b2d19bd793342..1f72a6ea06494394e3a92369d9354adca74a8af2 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "$schema": "https://json.schemastore.org/package.json", "name": "openwebui-usage-monitor", - "version": "0.3.5", + "version": "0.3.7", "private": true, "scripts": { "dev": "next dev", diff --git a/public/static/favicon.png b/public/static/favicon.png index 389196ca6a364b9e4b7daa0fc13be463b914b251..1d171171f487203cda8857a88c2fb52df2728bd0 100644 Binary files a/public/static/favicon.png and b/public/static/favicon.png differ diff --git a/resources/functions/openwebui_monitor.py b/resources/functions/openwebui_monitor.py index 328a2c898f3fa2cba5b18aa7fa88816806570a40..d8d648496b2f28a8768508f30f20ed66a67752b9 100644 --- a/resources/functions/openwebui_monitor.py +++ b/resources/functions/openwebui_monitor.py @@ -2,20 +2,43 @@ title: Usage Monitor author: VariantConst & OVINC CN git_url: https://github.com/VariantConst/OpenWebUI-Monitor.git -version: 0.3.5 +version: 0.3.6 requirements: httpx license: MIT """ import logging +import time from typing import Dict, Optional - from httpx import AsyncClient from pydantic import BaseModel, Field +import json + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +TRANSLATIONS = { + "en": { + "request_failed": "Request failed: {error_msg}", + "insufficient_balance": "Insufficient balance: Current balance `{balance:.4f}`", + "cost": "Cost: ${cost:.4f}", + "balance": "Balance: ${balance:.4f}", + "tokens": "Tokens: {input}+{output}", + "time_spent": "Time: {time:.2f}s", + "tokens_per_sec": "{tokens_per_sec:.2f} T/s", + }, + "zh": { + "request_failed": "请求失败: {error_msg}", + "insufficient_balance": "余额不足: 当前余额 `{balance:.4f}`", + "cost": "费用: ¥{cost:.4f}", + "balance": "余额: ¥{balance:.4f}", + "tokens": "Token: {input}+{output}", + "time_spent": "耗时: {time:.2f}s", + "tokens_per_sec": "{tokens_per_sec:.2f} T/s", + }, +} + class CustomException(Exception): pass @@ -26,25 +49,41 @@ class Filter: api_endpoint: str = Field(default="", description="openwebui-monitor's base url") api_key: str = Field(default="", description="openwebui-monitor's api key") priority: int = Field(default=5, description="filter priority") + language: str = Field(default="zh", description="language (en/zh)") + show_time_spent: bool = Field(default=True, description="show time spent") + show_tokens_per_sec: bool = Field(default=True, description="show tokens per second") + show_cost: bool = Field(default=True, description="show cost") + show_balance: bool = Field(default=True, description="show balance") + show_tokens: bool = Field(default=True, description="show tokens") def __init__(self): self.type = "filter" + self.name = "OpenWebUI Monitor" self.valves = self.Valves() self.outage_map: Dict[str, bool] = {} + self.start_time: Optional[float] = None + + def get_text(self, key: str, **kwargs) -> str: + lang = self.valves.language if self.valves.language in TRANSLATIONS else "en" + text = TRANSLATIONS[lang].get(key, TRANSLATIONS["en"][key]) + return text.format(**kwargs) if kwargs else text + + async def request(self, client: AsyncClient, url: str, headers: dict, json_data: dict): + json_data = json.loads(json.dumps(json_data, default=lambda o: o.dict() if hasattr(o, "dict") else str(o))) - async def request(self, client: AsyncClient, url: str, headers: dict, json: dict): - response = await client.post(url=url, headers=headers, json=json) + response = await client.post(url=url, headers=headers, json=json_data) response.raise_for_status() response_data = response.json() if not response_data.get("success"): - logger.error("[usage_monitor] req monitor failed: %s", response_data) - raise CustomException("calculate usage failed, please contact administrator") + logger.error(self.get_text("request_failed", error_msg=response_data)) + raise CustomException(self.get_text("request_failed", error_msg=response_data)) return response_data async def inlet(self, body: dict, __metadata__: Optional[dict] = None, __user__: Optional[dict] = None) -> dict: __user__ = __user__ or {} __metadata__ = __metadata__ or {} - user_id = __user__["id"] + self.start_time = time.time() + user_id = __user__.get("id", "default") client = AsyncClient() @@ -53,17 +92,16 @@ class Filter: client=client, url=f"{self.valves.api_endpoint}/api/v1/inlet", headers={"Authorization": f"Bearer {self.valves.api_key}"}, - json={"user": __user__, "body": body}, + json_data={"user": __user__, "body": body}, ) self.outage_map[user_id] = response_data.get("balance", 0) <= 0 if self.outage_map[user_id]: - logger.info("[usage_monitor] no balance: %s", user_id) - raise CustomException("no balance, please contact administrator") - + logger.info(self.get_text("insufficient_balance", balance=response_data.get("balance", 0))) + raise CustomException(self.get_text("insufficient_balance", balance=response_data.get("balance", 0))) return body except Exception as err: - logger.exception("[usage_monitor] error calculating usage: %s", err) + logger.exception(self.get_text("request_failed", error_msg=err)) if isinstance(err, CustomException): raise err raise Exception(f"error calculating usage, {err}") from err @@ -76,13 +114,13 @@ class Filter: body: dict, __metadata__: Optional[dict] = None, __user__: Optional[dict] = None, - __event_emitter__: callable = None, + __event_emitter__: Optional[callable] = None, ) -> dict: __user__ = __user__ or {} __metadata__ = __metadata__ or {} - user_id = __user__["id"] + user_id = __user__.get("id", "default") - if self.outage_map[user_id]: + if self.outage_map.get(user_id, False): return body client = AsyncClient() @@ -92,26 +130,32 @@ class Filter: client=client, url=f"{self.valves.api_endpoint}/api/v1/outlet", headers={"Authorization": f"Bearer {self.valves.api_key}"}, - json={"user": __user__, "body": body}, - ) - - # pylint: disable=C0209 - stats = " | ".join( - [ - f"Tokens: {response_data['inputTokens']} + {response_data['outputTokens']}", - "Cost: %.4f" % response_data["totalCost"], - "Balance: %.4f" % response_data["newBalance"], - ] + json_data={"user": __user__, "body": body}, ) - await __event_emitter__({"type": "status", "data": {"description": stats, "done": True}}) + stats_list = [] + if self.valves.show_tokens: + stats_list.append(self.get_text("tokens", input=response_data["inputTokens"], output=response_data["outputTokens"])) + if self.valves.show_cost: + stats_list.append(self.get_text("cost", cost=response_data["totalCost"])) + if self.valves.show_balance: + stats_list.append(self.get_text("balance", balance=response_data["newBalance"])) + if self.start_time and self.valves.show_time_spent: + elapsed = time.time() - self.start_time + stats_list.append(self.get_text("time_spent", time=elapsed)) + if self.valves.show_tokens_per_sec: + tokens_per_sec = (response_data["outputTokens"] / elapsed if elapsed > 0 else 0) + stats_list.append(self.get_text("tokens_per_sec", tokens_per_sec=tokens_per_sec)) + + stats = " | ".join(stats_list) + if __event_emitter__: + await __event_emitter__({"type": "status", "data": {"description": stats, "done": True}}) logger.info("usage_monitor: %s %s", user_id, stats) return body except Exception as err: - logger.exception("[usage_monitor] error calculating usage: %s", err) - raise Exception(f"error calculating usage, {err}") from err - + logger.exception(self.get_text("request_failed", error_msg=err)) + raise Exception(self.get_text("request_failed", error_msg=err)) finally: await client.aclose() diff --git a/scripts/init-db.ts b/scripts/init-db.ts index 799930926ef09ab8685738d9c0996052fced1956..8686419406cab2ad966d71432841ca47355318d3 100644 --- a/scripts/init-db.ts +++ b/scripts/init-db.ts @@ -1,4 +1,4 @@ -import { ensureTablesExist } from "../lib/db"; +import { ensureTablesExist } from "../lib/db/client"; async function init() { try {