# conda activate rzwl && uvicorn app:app --host 0.0.0.0 --port 7861 --reload
"""Backend API (FastAPI + SQLite): user auth, WeChat login/verification,
enrollments, orders, courses and image uploads."""
import hashlib  # SHA1 for WeChat signature checks
import hmac  # constant-time signature comparison
import json
import os
import sqlite3
import time
import urllib.parse  # URL encoding for the WeChat QR-code login URL
import uuid
from contextlib import closing
from typing import Dict, List, Optional

import httpx  # async HTTP client for WeChat API calls
from fastapi import APIRouter, FastAPI, File, HTTPException, Query, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from fastapi.staticfiles import StaticFiles
from passlib.context import CryptContext  # password hashing
from pydantic import BaseModel, Field

# Largest image accepted by /api/upload/images, in bytes (10 MB).
MAX_UPLOAD_SIZE = 10 * 1024 * 1024

# NOTE: ``max_upload_size`` is not a FastAPI setting; FastAPI accepts unknown
# keywords via ``**extra`` and otherwise ignores them. The keyword is kept for
# backward compatibility; the limit is actually enforced in ``upload_image``.
app = FastAPI(max_upload_size=MAX_UPLOAD_SIZE)

DATABASE_URL = "ai_edu.db"

# Directory served under /upload, and the sub-directory uploaded images go into.
UPLOAD_ROOT = "upload"
IMAGE_UPLOAD_DIR = os.path.join(UPLOAD_ROOT, "images")

# bcrypt-based hashing context for user passwords.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password, hashed_password):
    """Return True if ``plain_password`` matches the stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password):
    """Return a bcrypt hash of ``password`` suitable for storage."""
    return pwd_context.hash(password)


def _open_db(use_row_factory=False):
    """Open a connection to DATABASE_URL for use in a ``with`` statement.

    ``sqlite3.Connection``'s own context manager only commits or rolls back;
    it never closes the connection. Wrapping it in ``contextlib.closing``
    guarantees the connection is closed when the ``with`` block exits.
    Callers must call ``conn.commit()`` explicitly after writes; uncommitted
    changes are discarded when the connection is closed.
    """
    conn = sqlite3.connect(DATABASE_URL)
    if use_row_factory:
        conn.row_factory = sqlite3.Row  # rows support access by column name
    return closing(conn)


def init_db():
    """Create all tables if missing and add ``courses.page_info`` to older databases."""
    with _open_db() as conn:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                user_id TEXT PRIMARY KEY,
                username TEXT UNIQUE NOT NULL,
                password TEXT NOT NULL
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS enrollments (
                enrollment_id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                phone TEXT NOT NULL,
                company TEXT,
                position TEXT,
                email TEXT
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS enterprise_orders (
                order_id TEXT PRIMARY KEY,
                company_name TEXT NOT NULL,
                contact_person TEXT NOT NULL,
                contact_phone TEXT NOT NULL,
                course_name TEXT NOT NULL,
                quantity INTEGER NOT NULL,
                order_date TEXT NOT NULL
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS courses (
                course_id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL,
                description TEXT,
                price REAL NOT NULL,
                duration_hours INTEGER,
                level TEXT
            )
        """)
        # Migration: databases created before page_info existed lack the column.
        cursor.execute("PRAGMA table_info(courses);")
        column_names = [col[1] for col in cursor.fetchall()]
        if "page_info" not in column_names:
            cursor.execute("ALTER TABLE courses ADD COLUMN page_info TEXT;")
        conn.commit()


# Initialize the database when the application starts.
@app.on_event("startup")
async def startup_event():
    init_db()


# CORS: every origin is allowed, which is convenient for development;
# production should restrict this list to specific domains.
origins = [
    "*"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],  # all methods, including OPTIONS preflight
    allow_headers=["*"],  # all request headers
)

api_router = APIRouter(prefix="/api")


class RegisterRequest(BaseModel):
    username: str
    password: str


class LoginRequest(BaseModel):
    username: str
    password: str


class EnrollmentIndividualRequest(BaseModel):
    name: str
    phone: str
    company: str
    position: str
    email: str


# WeChat configuration (placeholders; supply real values via environment variables).
WECHAT_APP_ID = os.getenv("WECHAT_APP_ID", "YOUR_WECHAT_APP_ID")
WECHAT_APP_SECRET = os.getenv("WECHAT_APP_SECRET", "YOUR_WECHAT_APP_SECRET")
# OAuth redirect target registered with WeChat.
WECHAT_REDIRECT_URI = os.getenv("WECHAT_REDIRECT_URI", "http://localhost:7861/api/auth/wechat/callback")
# Token configured in the WeChat backend, used for server signature verification.
WECHAT_TOKEN = os.getenv("WECHAT_TOKEN", "YOUR_WECHAT_VERIFICATION_TOKEN")


class Course(BaseModel):
    course_id: Optional[int] = None  # None means "create a new course"
    title: str
    description: Optional[str] = None
    price: float = 0
    duration_hours: Optional[int] = 0
    level: Optional[str] = None
    # Received as a JSON list of objects; stored in the DB as JSON text.
    page_info: Optional[List[Dict]] = None


@api_router.post("/auth/register")
async def register_user(request: RegisterRequest):
    """Create a user with a bcrypt-hashed password.

    Raises 400 if the username already exists (UNIQUE constraint) and 500 on
    any other failure.
    """
    user_id = str(uuid.uuid4())
    try:
        with _open_db() as conn:
            hashed_password = get_password_hash(request.password)
            conn.execute(
                "INSERT INTO users (user_id, username, password) VALUES (?, ?, ?)",
                (user_id, request.username, hashed_password),
            )
            conn.commit()
        return {
            "code": 200,
            "message": "注册成功",
            "data": {
                "user_id": user_id,
                "username": request.username,
            },
        }
    except sqlite3.IntegrityError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户名已存在",
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"注册失败: {str(e)}",
        )


@api_router.post("/auth/login")
async def login_user(request: LoginRequest):
    """Check a username/password pair.

    Always answers with HTTP 200; the outcome is in the body's ``code`` field
    (200 success, 401 bad credentials, 500 unexpected error). Unknown users
    and wrong passwords receive the same message.
    """
    try:
        with _open_db() as conn:
            user = conn.execute(
                "SELECT user_id, username, password FROM users WHERE username = ?",
                (request.username,),
            ).fetchone()
        if user:
            user_id, username, hashed_password = user
            if verify_password(request.password, hashed_password):
                return {
                    "code": 200,
                    "message": "登录成功",
                    "data": {
                        "user_id": user_id,
                        "username": username,
                    },
                }
        return {
            "code": 401,
            "message": "用户名或密码错误",
            "data": None,
        }
    except Exception as e:
        # Any unexpected failure is reported in the body, not as an HTTP error.
        print(f"登录失败未知错误: {type(e).__name__}: {e}")
        return {
            "code": 500,
            "message": f"登录失败: {str(e)}",
            "data": None,
        }


@api_router.get("/auth/wechat/qrcode")
async def get_wechat_qrcode():
    """Return the WeChat QR-code login URL and the random ``state`` embedded in it.

    NOTE(review): ``state`` is intended as CSRF protection but is not stored
    server-side, so ``wechat_callback`` cannot verify it yet.
    """
    state = str(uuid.uuid4())
    # urlencode percent-encodes every value (quote_plus), including redirect_uri.
    query = urllib.parse.urlencode({
        "appid": WECHAT_APP_ID,
        "redirect_uri": WECHAT_REDIRECT_URI,
        "response_type": "code",
        "scope": "snsapi_login",
        "state": state,
    })
    qrcode_url = f"https://open.weixin.qq.com/connect/qrconnect?{query}#wechat_redirect"
    return {
        "code": 200,
        "message": "获取微信二维码URL成功",
        "data": {
            "qrcode_url": qrcode_url,
            "state": state,  # returned so the frontend can check it on callback
        },
    }


@api_router.get("/auth/wechat/callback")
async def wechat_callback(code: str, state: str):
    """Exchange a WeChat OAuth ``code`` for an access token.

    Raises 400 if WeChat reports an error (``errcode`` in the reply) and 502
    if WeChat cannot be reached or replies with a non-JSON body.

    NOTE(review): ``state`` is only logged, not checked against the value
    issued by ``get_wechat_qrcode``, so this endpoint is not CSRF-protected.
    It also hands the raw access_token to the caller rather than creating a
    local session. TODO: look up/create the local user (optionally via the
    sns/userinfo API) once the login flow is finalized.
    """
    print(f"Received WeChat callback with code: {code} and state: {state}")

    token_url = "https://api.weixin.qq.com/sns/oauth2/access_token"
    params = {
        "appid": WECHAT_APP_ID,
        "secret": WECHAT_APP_SECRET,
        "code": code,  # passed via params so it is URL-encoded
        "grant_type": "authorization_code",
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(token_url, params=params)
        token_data = response.json()
    except (httpx.HTTPError, ValueError) as e:
        # ValueError covers a non-JSON body (json.JSONDecodeError subclasses it).
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"微信服务请求失败: {str(e)}",
        )

    if "errcode" in token_data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"微信授权失败: {token_data.get('errmsg', '未知错误')}",
        )

    return {
        "code": 200,
        "message": "微信登录回调成功",
        "data": {
            "access_token": token_data.get("access_token"),
            "openid": token_data.get("openid"),
            # May be absent depending on the scope / account binding.
            "unionid": token_data.get("unionid"),
        },
    }


def _wechat_signature(token, timestamp, nonce):
    """Compute the WeChat server-verification signature.

    The three values are sorted lexicographically, concatenated and SHA1-hashed;
    the hex digest is what WeChat sends as ``signature``.
    """
    joined = "".join(sorted([token, timestamp, nonce]))
    return hashlib.sha1(joined.encode("utf-8")).hexdigest()


def _signature_matches(expected, received):
    """Compare two signature strings in constant time.

    Both are encoded to bytes first: ``hmac.compare_digest`` raises TypeError
    on non-ASCII ``str`` input, and ``received`` comes from the query string.
    """
    return hmac.compare_digest(expected.encode("utf-8"), received.encode("utf-8"))


@api_router.get("/wechat/verify")
async def wechat_verify(signature: str, timestamp: str, nonce: str, echostr: str):
    """WeChat server-configuration check.

    Echoes ``echostr`` back as plain text when the signature computed from
    WECHAT_TOKEN, ``timestamp`` and ``nonce`` matches ``signature``; raises
    400 otherwise. WeChat compares the raw response body with ``echostr``, so
    it must not be JSON-encoded (a plain ``str`` return would add quotes).
    """
    expected = _wechat_signature(WECHAT_TOKEN, timestamp, nonce)
    if _signature_matches(expected, signature):
        return PlainTextResponse(echostr)
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="微信验证失败",
    )


@api_router.post("/enrollment/individual")
async def enroll_individual(request: EnrollmentIndividualRequest):
    """Store an individual enrollment and echo it back with its new id."""
    enrollment_id = str(uuid.uuid4())
    try:
        with _open_db() as conn:
            conn.execute(
                "INSERT INTO enrollments (enrollment_id, name, phone, company, position, email) VALUES (?, ?, ?, ?, ?, ?)",
                (enrollment_id, request.name, request.phone, request.company, request.position, request.email),
            )
            conn.commit()
        return {
            "code": 200,
            "message": "个人报名成功",
            "data": {
                "enrollment_id": enrollment_id,
                "name": request.name,
                "phone": request.phone,
                "company": request.company,
                "position": request.position,
                "email": request.email,
            },
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"报名失败: {str(e)}",
        )


@api_router.get("/orders/enterprise")
async def get_enterprise_orders():
    """Return every row of ``enterprise_orders`` as a list of dicts."""
    try:
        with _open_db(use_row_factory=True) as conn:
            orders = conn.execute("SELECT * FROM enterprise_orders").fetchall()
        return {
            "code": 200,
            "message": "获取企业订单成功",
            "data": [dict(order) for order in orders],
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取企业订单失败: {str(e)}",
        )


@api_router.get("/orders/individual")
async def get_individual_orders():
    """Return every individual enrollment as a list of dicts."""
    try:
        with _open_db(use_row_factory=True) as conn:
            enrollments = conn.execute(
                "SELECT enrollment_id, name, phone, company, position, email FROM enrollments"
            ).fetchall()
        return {
            "code": 200,
            "message": "获取个人报名订单成功",
            "data": [dict(enrollment) for enrollment in enrollments],
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取个人报名订单失败: {str(e)}",
        )


@api_router.post("/course/save")
async def save_course(course: Course):
    """Insert a course (when ``course_id`` is None) or update an existing one.

    ``page_info`` is stored as JSON text. The response echoes the fields the
    client sent plus the effective ``course_id``. Raises 404 when updating a
    course_id that does not exist, 500 on any other failure.
    """
    try:
        # Serialize page_info to a JSON string for the TEXT column.
        page_info_json = json.dumps(course.page_info) if course.page_info is not None else None
        with _open_db() as conn:
            cursor = conn.cursor()
            if course.course_id is not None:
                cursor.execute(
                    "UPDATE courses SET title = ?, description = ?, price = ?, duration_hours = ?, level = ?, page_info = ? WHERE course_id = ?",
                    (course.title, course.description, course.price, course.duration_hours, course.level, page_info_json, course.course_id),
                )
                if cursor.rowcount == 0:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="未找到该课程进行更新",
                    )
                conn.commit()
                course_id = course.course_id
                message = "课程更新成功"
            else:
                cursor.execute(
                    "INSERT INTO courses (title, description, price, duration_hours, level, page_info) VALUES (?, ?, ?, ?, ?, ?)",
                    (course.title, course.description, course.price, course.duration_hours, course.level, page_info_json),
                )
                conn.commit()
                course_id = cursor.lastrowid
                message = "课程添加成功"
        return {
            "code": 200,
            "message": message,
            # course_id goes last so an explicit ``"course_id": null`` in the
            # request (kept by exclude_unset) cannot overwrite the real id.
            "data": {
                **course.dict(exclude_unset=True),
                "course_id": course_id,
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"保存课程失败: {str(e)}",
        )


@api_router.get("/course/details/{course_id}")
async def get_course_details(course_id: int):
    """Return one course with ``page_info`` decoded from JSON; 404 if missing."""
    try:
        with _open_db(use_row_factory=True) as conn:
            course = conn.execute(
                "SELECT * FROM courses WHERE course_id = ?", (course_id,)
            ).fetchone()
        if not course:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="未找到该课程",
            )
        course_dict = dict(course)
        if course_dict["page_info"]:
            course_dict["page_info"] = json.loads(course_dict["page_info"])
        return {
            "code": 200,
            "message": "获取课程详情成功",
            "data": course_dict,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取课程详情失败: {str(e)}",
        )


@api_router.delete("/enrollment/individual/{enrollment_id}")
async def delete_individual_enrollment(enrollment_id: str):
    """Delete one enrollment by id; 404 if no row matched."""
    try:
        with _open_db() as conn:
            cursor = conn.execute(
                "DELETE FROM enrollments WHERE enrollment_id = ?", (enrollment_id,)
            )
            if cursor.rowcount == 0:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="未找到该报名数据",
                )
            conn.commit()
        return {
            "code": 200,
            "message": "报名数据删除成功",
            "data": {"enrollment_id": enrollment_id},
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除报名数据失败: {str(e)}",
        )


@api_router.post("/upload/images")
async def upload_image(image: UploadFile = File(...)):
    """Save an uploaded image under ``upload/images`` and return its public URL.

    The form field is named ``image``. An existing file with the same name is
    overwritten. Raises 400 for a missing or unsafe filename, 413 when the file
    exceeds MAX_UPLOAD_SIZE, and 500 on any other failure.
    """
    # Keep only the final path component (treating "\" as a separator too) so a
    # crafted name such as "../../app.py" cannot escape the upload directory.
    filename = os.path.basename((image.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的文件名",
        )
    file_location = os.path.join(IMAGE_UPLOAD_DIR, filename)
    try:
        # Read one byte past the limit so oversized files are detected without
        # loading an arbitrarily large body into memory.
        content = await image.read(MAX_UPLOAD_SIZE + 1)
        if len(content) > MAX_UPLOAD_SIZE:
            raise HTTPException(
                status_code=413,
                detail="图片大小超过限制 (10MB)",
            )
        os.makedirs(IMAGE_UPLOAD_DIR, exist_ok=True)
        with open(file_location, "wb") as buffer:
            buffer.write(content)
        return {
            "code": 200,
            "message": "图片上传成功",
            "data": {
                "filename": filename,
                "url": f"/upload/images/{filename}",
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"图片上传失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"图片上传失败: {str(e)}",
        )


# Serve uploaded files. StaticFiles checks at construction time that the
# directory exists, so create it first or a fresh checkout fails to start.
os.makedirs(IMAGE_UPLOAD_DIR, exist_ok=True)
app.mount("/upload", StaticFiles(directory=UPLOAD_ROOT), name="upload")


@app.get("/")
def greet_json():
    """Trivial health-check endpoint."""
    return {"Hello": "World!"}


@app.get("/token")
async def wechat_token_check(
    signature: str = Query(..., description="微信加密签名"),
    timestamp: str = Query(..., description="时间戳"),
    nonce: str = Query(..., description="随机数"),
    echostr: str = Query(..., description="验证字符串"),
):
    """WeChat server verification with a replay window.

    Rejects a non-integer ``timestamp`` (400) or one more than 300 seconds away
    from server time (401). Then echoes ``echostr`` as plain text if the
    signature matches, otherwise raises 401. WeChat compares the raw body with
    ``echostr``, so it must not be JSON-encoded.
    """
    hashcode = _wechat_signature(WECHAT_TOKEN, timestamp, nonce)

    try:
        request_time = int(timestamp)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="时间戳无效",
        ) from None
    if abs(int(time.time()) - request_time) > 300:  # 300 s = 5 minutes
        raise HTTPException(401, "时间戳过期")

    if _signature_matches(hashcode, signature):
        print("微信服务器验证成功: echostr", echostr)
        return PlainTextResponse(echostr)
    raise HTTPException(
        status_code=401,
        detail="签名验证失败",
        headers={"WWW-Authenticate": "Bearer"},
    )


app.include_router(api_router)