# Run locally: conda activate rzwl && uvicorn app:app --host 0.0.0.0 --port 7861 --reload
import hashlib
import hmac
import json
import os
import sqlite3
import time
import urllib.parse
import uuid
from contextlib import closing
from typing import Dict, List, Optional

import httpx
from fastapi import APIRouter, FastAPI, File, HTTPException, Query, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from passlib.context import CryptContext
from pydantic import BaseModel, Field
# NOTE(review): FastAPI declares no `max_upload_size` parameter; unknown keyword
# arguments are presumably just stored on `app.extra`, so this does not by itself
# cap request bodies at 10 MB -- confirm the limit is enforced somewhere.
app = FastAPI(max_upload_size=10 * 1024 * 1024)
# Path handed to sqlite3.connect() by every database helper in this module.
DATABASE_URL = "ai_edu.db"
# CryptContext used for password hashing (bcrypt scheme).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Delegates to the module-level ``pwd_context`` and returns whatever
    ``pwd_context.verify`` returns.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Hash ``password`` with the module-level ``pwd_context`` and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
def init_db():
    """Create the application's SQLite tables if they do not exist yet.

    Creates ``users``, ``enrollments``, ``enterprise_orders`` and ``courses``
    in the database file named by ``DATABASE_URL``, then adds the ``page_info``
    column to ``courses`` when an existing table lacks it.  Calling it
    repeatedly is safe: every statement is guarded by ``IF NOT EXISTS`` or by
    the column check.
    """
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS users (
            user_id TEXT PRIMARY KEY,
            username TEXT UNIQUE NOT NULL,
            password TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS enrollments (
            enrollment_id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            phone TEXT NOT NULL,
            company TEXT,
            position TEXT,
            email TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS enterprise_orders (
            order_id TEXT PRIMARY KEY,
            company_name TEXT NOT NULL,
            contact_person TEXT NOT NULL,
            contact_phone TEXT NOT NULL,
            course_name TEXT NOT NULL,
            quantity INTEGER NOT NULL,
            order_date TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS courses (
            course_id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            description TEXT,
            price REAL NOT NULL,
            duration_hours INTEGER,
            level TEXT
        )
        """,
    )
    # A sqlite3 connection used as a context manager only commits or rolls back
    # the transaction; closing() is what actually releases the connection.
    with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
        cursor = conn.cursor()
        for statement in schema_statements:
            cursor.execute(statement)
        # Lightweight migration: PRAGMA table_info yields one row per column,
        # with the column name at index 1.
        cursor.execute("PRAGMA table_info(courses)")
        existing_columns = {row[1] for row in cursor.fetchall()}
        if "page_info" not in existing_columns:
            cursor.execute("ALTER TABLE courses ADD COLUMN page_info TEXT")
# Initialise the database when the application starts.
# NOTE(review): no startup-hook registration (decorator or add_event_handler) is
# visible in this view of the file -- confirm this coroutine is actually wired up.
async def startup_event():
    """Ensure the SQLite schema exists by running ``init_db``."""
    init_db()
# Configure the CORS middleware.
origins = [
    "*"  # Allow every origin: convenient in development; restrict to specific domains in production.
]
# NOTE(review): a wildcard origin combined with allow_credentials=True presumably
# lets any site issue credentialed requests -- verify this is intended before
# deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods, including OPTIONS.
    allow_headers=["*"],  # Allow all headers.
)
# Router whose routes are all served under the /api prefix.
api_router = APIRouter(prefix="/api")
class RegisterRequest(BaseModel):
    """Request body accepted by ``register_user``."""

    # Login name; inserted into users.username, which is declared UNIQUE.
    username: str
    # Plaintext password as submitted; register_user hashes it before storing.
    password: str
class LoginRequest(BaseModel):
    """Request body accepted by ``login_user``."""

    # Login name looked up in users.username.
    username: str
    # Plaintext password checked against the stored hash.
    password: str
class EnrollmentIndividualRequest(BaseModel):
    """Request body accepted by ``enroll_individual``.

    All five fields are required strings here, although the ``enrollments``
    table only marks ``name`` and ``phone`` as NOT NULL.
    """

    name: str
    phone: str
    company: str
    position: str
    email: str
# WeChat configuration, read from environment variables.
# The fallback strings are placeholders only; real deployments must set the
# environment variables (never commit actual secrets as defaults).
WECHAT_APP_ID = os.getenv("WECHAT_APP_ID", "YOUR_WECHAT_APP_ID")
WECHAT_APP_SECRET = os.getenv("WECHAT_APP_SECRET", "YOUR_WECHAT_APP_SECRET")  # App secret used when exchanging the auth code.
# Redirect target sent to WeChat when building the QR-code login URL.
# NOTE(review): the original comment calls this the frontend callback URL, but the
# default points at this API's /api/auth/wechat/callback path -- confirm which is meant.
WECHAT_REDIRECT_URI = os.getenv("WECHAT_REDIRECT_URI", "http://localhost:7861/api/auth/wechat/callback")
WECHAT_TOKEN = os.getenv("WECHAT_TOKEN", "YOUR_WECHAT_VERIFICATION_TOKEN")  # Token used in the WeChat server-verification signature.
class Course(BaseModel):
    """Course payload accepted by ``save_course``.

    ``save_course`` updates the existing row when ``course_id`` is set and
    inserts a new row when it is ``None``.
    """

    course_id: Optional[int] = None
    title: str
    description: Optional[str] = None
    price: float = 0
    duration_hours: Optional[int] = 0
    level: Optional[str] = None
    # Received as a JSON structure (list of objects) rather than a
    # pre-serialised string; save_course serialises it with json.dumps.
    page_info: Optional[List[Dict]] = None
async def register_user(request: RegisterRequest):
    """Create a new user account, storing a hash of the password.

    Args:
        request: Payload carrying the desired ``username`` and ``password``.

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` holds the
        generated ``user_id`` and the ``username``.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint
            (reported as "username already exists"); 500 for any other failure.
    """
    user_id = str(uuid.uuid4())
    try:
        # Hash before opening the connection so the hashing work does not keep
        # a database handle open.
        hashed_password = get_password_hash(request.password)
        # closing() releases the connection; the inner `conn` context commits
        # on success and rolls back on error.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            conn.execute(
                "INSERT INTO users (user_id, username, password) VALUES (?, ?, ?)",
                (user_id, request.username, hashed_password),
            )
    except sqlite3.IntegrityError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户名已存在",
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"注册失败: {str(e)}",
        ) from e
    return {
        "code": 200,
        "message": "注册成功",
        "data": {
            "user_id": user_id,
            "username": request.username,
        },
    }
async def login_user(request: LoginRequest):
    """Authenticate a user by username and password.

    Args:
        request: Payload carrying ``username`` and ``password``.

    Returns:
        A ``{"code", "message", "data"}`` envelope.  ``code`` is 200 with the
        user's ``user_id``/``username`` on success, 401 when the username is
        unknown or the password does not match, and 500 on unexpected errors.
        The function never raises, so failures are reported in the body only.
    """
    try:
        # closing() makes sure the connection is released; sqlite3's own
        # context manager would only manage the transaction.
        with closing(sqlite3.connect(DATABASE_URL)) as conn:
            row = conn.execute(
                "SELECT user_id, username, password FROM users WHERE username = ?",
                (request.username,),
            ).fetchone()

        # Unknown user and wrong password deliberately share one response so
        # the caller cannot tell which of the two failed.
        if row is not None:
            user_id, username, hashed_password = row
            if verify_password(request.password, hashed_password):
                return {
                    "code": 200,
                    "message": "登录成功",
                    "data": {
                        "user_id": user_id,
                        "username": username,
                    },
                }
        return {
            "code": 401,
            "message": "用户名或密码错误",
            "data": None,
        }
    except Exception as e:
        # Catch-all boundary: report the failure in the 500-style envelope.
        print(f"登录失败未知错误: {type(e).__name__}: {e}")
        return {
            "code": 500,
            "message": f"登录失败: {str(e)}",
            "data": None,
        }
async def get_wechat_qrcode():
    """Build the WeChat QR-code login URL together with a fresh ``state`` token.

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` carries
        ``qrcode_url`` and the generated ``state``.
    """
    # Random token meant to protect against CSRF.  It is not stored
    # server-side here, so the callback cannot verify it yet.
    csrf_state = str(uuid.uuid4())
    redirect_param = urllib.parse.quote_plus(WECHAT_REDIRECT_URI)
    query_string = "&".join(
        [
            f"appid={WECHAT_APP_ID}",
            f"redirect_uri={redirect_param}",
            "response_type=code",
            "scope=snsapi_login",
            f"state={csrf_state}",
        ]
    )
    login_url = (
        "https://open.weixin.qq.com/connect/qrconnect?"
        + query_string
        + "#wechat_redirect"
    )
    payload = {
        "qrcode_url": login_url,
        "state": csrf_state,  # handed to the frontend so it can verify it later
    }
    return {
        "code": 200,
        "message": "获取微信二维码URL成功",
        "data": payload,
    }
async def wechat_callback(code: str, state: str):
    """Exchange a WeChat authorization code for an access token.

    Args:
        code: Authorization code supplied by the caller (from WeChat's redirect).
        state: State value echoed back by WeChat.  It is only logged here.

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` holds the
        ``access_token``, ``openid`` and ``unionid`` from WeChat's reply.

    Raises:
        HTTPException: 400 when WeChat's reply contains ``errcode``; 502 when
            the token endpoint cannot be reached or returns an unusable body.
    """
    # NOTE(review): `state` should be checked against a server-side stored value
    # to prevent CSRF; nothing in this function verifies it.
    print(f"Received WeChat callback with code: {code} and state: {state}")

    # Pass the values through `params` so httpx URL-encodes them.  Interpolating
    # the caller-supplied `code` into the URL would let it smuggle extra query
    # parameters (e.g. "x&appid=...") into the request.
    token_endpoint = "https://api.weixin.qq.com/sns/oauth2/access_token"
    query = {
        "appid": WECHAT_APP_ID,
        "secret": WECHAT_APP_SECRET,
        "code": code,
        "grant_type": "authorization_code",
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(token_endpoint, params=query)
        token_data = response.json()
    except (httpx.HTTPError, ValueError) as e:
        # Transport failure or a body that is not valid JSON.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="微信授权失败: 无法获取 access_token",
        ) from e
    if not isinstance(token_data, dict):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="微信授权失败: 无法获取 access_token",
        )

    if "errcode" in token_data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"微信授权失败: {token_data.get('errmsg', '未知错误')}",
        )

    # The reply may omit unionid (the original note says this happens when the
    # scope is only snsapi_base), so every field is read with .get().
    # A real application would now log the user in or fetch their profile.
    return {
        "code": 200,
        "message": "微信登录回调成功",
        "data": {
            "access_token": token_data.get("access_token"),
            "openid": token_data.get("openid"),
            "unionid": token_data.get("unionid"),
        },
    }
async def wechat_verify(signature: str, timestamp: str, nonce: str, echostr: str):
    """Validate a WeChat server-configuration verification request.

    The expected signature is the SHA-1 hex digest of ``WECHAT_TOKEN``,
    ``timestamp`` and ``nonce`` sorted lexicographically and concatenated.

    Args:
        signature: Signature sent with the request.
        timestamp: Timestamp sent with the request.
        nonce: Random string sent with the request.
        echostr: Value to echo back when the signature matches.

    Returns:
        ``echostr`` unchanged when the signature is valid.

    Raises:
        HTTPException: 400 when the signature does not match.
    """
    # NOTE(review): WeChat presumably expects the bare echostr as the response
    # body -- confirm the route is declared with a plain-text response class.
    joined = "".join(sorted([WECHAT_TOKEN, timestamp, nonce]))
    expected = hashlib.sha1(joined.encode("utf-8")).hexdigest()
    # Constant-time comparison; bytes operands so a non-ASCII signature is
    # simply rejected instead of raising TypeError.
    if not hmac.compare_digest(expected.encode("utf-8"), signature.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="微信验证失败",
        )
    return echostr
async def enroll_individual(request: EnrollmentIndividualRequest):
    """Store an individual enrollment and echo the saved values back.

    Args:
        request: Enrollment details (name, phone, company, position, email).

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` contains the
        generated ``enrollment_id`` plus the submitted fields.

    Raises:
        HTTPException: 500 when the insert fails.
    """
    enrollment_id = str(uuid.uuid4())
    try:
        # closing() releases the connection; the inner `conn` context commits
        # on success and rolls back on error.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            conn.execute(
                "INSERT INTO enrollments (enrollment_id, name, phone, company, position, email) VALUES (?, ?, ?, ?, ?, ?)",
                (enrollment_id, request.name, request.phone, request.company, request.position, request.email),
            )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"报名失败: {str(e)}",
        ) from e
    return {
        "code": 200,
        "message": "个人报名成功",
        "data": {
            "enrollment_id": enrollment_id,
            "name": request.name,
            "phone": request.phone,
            "company": request.company,
            "position": request.position,
            "email": request.email,
        },
    }
async def get_enterprise_orders():
    """Return every row of ``enterprise_orders`` as a list of dicts.

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` is the list
        of orders, one dict per row keyed by column name.

    Raises:
        HTTPException: 500 when the query fails.
    """
    try:
        # closing() makes sure the connection is released after the query.
        with closing(sqlite3.connect(DATABASE_URL)) as conn:
            conn.row_factory = sqlite3.Row  # lets rows be converted by column name
            rows = conn.execute("SELECT * FROM enterprise_orders").fetchall()
            orders_list = [dict(row) for row in rows]
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取企业订单失败: {str(e)}",
        ) from e
    return {
        "code": 200,
        "message": "获取企业订单成功",
        "data": orders_list,
    }
async def get_individual_orders():
    """Return every individual enrollment as a list of dicts.

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` is the list
        of enrollments with ``enrollment_id``, ``name``, ``phone``,
        ``company``, ``position`` and ``email``.

    Raises:
        HTTPException: 500 when the query fails.
    """
    try:
        # closing() makes sure the connection is released after the query.
        with closing(sqlite3.connect(DATABASE_URL)) as conn:
            conn.row_factory = sqlite3.Row  # lets rows be converted by column name
            rows = conn.execute(
                "SELECT enrollment_id, name, phone, company, position, email FROM enrollments"
            ).fetchall()
            enrollments_list = [dict(row) for row in rows]
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取个人报名订单失败: {str(e)}",
        ) from e
    return {
        "code": 200,
        "message": "获取个人报名订单成功",
        "data": enrollments_list,
    }
async def save_course(course: Course):
    """Insert a new course or update an existing one.

    When ``course.course_id`` is set the matching row is updated; otherwise a
    new row is inserted and its generated id is returned.

    Args:
        course: Course payload; ``page_info`` is stored as a JSON string.

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` holds the
        effective ``course_id`` followed by the fields the client explicitly
        sent.

    Raises:
        HTTPException: 404 when updating a course id that does not exist;
            500 for any other failure.
    """
    try:
        # page_info is persisted as a JSON string (NULL when absent).
        page_info_json = json.dumps(course.page_info) if course.page_info is not None else None
        # Drop any client-sent course_id from the echoed fields: an explicit
        # `"course_id": null` would otherwise overwrite the real id below.
        submitted = course.dict(exclude_unset=True)
        submitted.pop("course_id", None)

        # closing() releases the connection; the inner `conn` context commits
        # on success and rolls back on error.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            cursor = conn.cursor()
            if course.course_id is not None:
                # Update the existing course.
                cursor.execute(
                    "UPDATE courses SET title = ?, description = ?, price = ?, duration_hours = ?, level = ?, page_info = ? WHERE course_id = ?",
                    (course.title, course.description, course.price, course.duration_hours, course.level, page_info_json, course.course_id),
                )
                if cursor.rowcount == 0:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="未找到该课程进行更新",
                    )
                return {
                    "code": 200,
                    "message": "课程更新成功",
                    "data": {"course_id": course.course_id, **submitted},
                }

            # Insert a new course.
            cursor.execute(
                "INSERT INTO courses (title, description, price, duration_hours, level, page_info) VALUES (?, ?, ?, ?, ?, ?)",
                (course.title, course.description, course.price, course.duration_hours, course.level, page_info_json),
            )
            return {
                "code": 200,
                "message": "课程添加成功",
                "data": {"course_id": cursor.lastrowid, **submitted},
            }
    except HTTPException:
        # Let the deliberate 404 through untouched.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"保存课程失败: {str(e)}",
        ) from e
async def get_course_details(course_id: int):
    """Fetch one course by id, decoding its stored ``page_info`` JSON.

    Args:
        course_id: Primary key of the course to load.

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` is the course
        row as a dict; a non-empty ``page_info`` is parsed from JSON.

    Raises:
        HTTPException: 404 when no course has this id; 500 for any other
            failure (including undecodable ``page_info``).
    """
    try:
        # closing() makes sure the connection is released after the query.
        with closing(sqlite3.connect(DATABASE_URL)) as conn:
            conn.row_factory = sqlite3.Row
            row = conn.execute(
                "SELECT * FROM courses WHERE course_id = ?", (course_id,)
            ).fetchone()
            course_dict = dict(row) if row else None

        if course_dict is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="未找到该课程",
            )
        if course_dict["page_info"]:
            course_dict["page_info"] = json.loads(course_dict["page_info"])
        return {
            "code": 200,
            "message": "获取课程详情成功",
            "data": course_dict,
        }
    except HTTPException:
        # Let the deliberate 404 through untouched.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取课程详情失败: {str(e)}",
        ) from e
async def delete_individual_enrollment(enrollment_id: str):
    """Delete one individual enrollment by id.

    Args:
        enrollment_id: Primary key of the enrollment to remove.

    Returns:
        A ``{"code", "message", "data"}`` envelope echoing the deleted
        ``enrollment_id``.

    Raises:
        HTTPException: 404 when no row matched; 500 for any other failure.
    """
    try:
        # closing() releases the connection; the inner `conn` context commits
        # on success and rolls back on error.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            cursor = conn.execute(
                "DELETE FROM enrollments WHERE enrollment_id = ?", (enrollment_id,)
            )
            deleted_rows = cursor.rowcount
        if deleted_rows == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="未找到该报名数据",
            )
        return {
            "code": 200,
            "message": "报名数据删除成功",
            "data": {"enrollment_id": enrollment_id},
        }
    except HTTPException:
        # Let the deliberate 404 through untouched.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除报名数据失败: {str(e)}",
        ) from e
# Upload size cap enforced by upload_image.  It mirrors the 10 MB value passed as
# FastAPI(max_upload_size=...), which FastAPI itself does not act on.
_MAX_IMAGE_UPLOAD_BYTES = 10 * 1024 * 1024


async def upload_image(image: UploadFile = File(...)):
    """Save an uploaded image under ``upload/images`` and return its URL.

    Args:
        image: The uploaded file (multipart form field named ``image``).

    Returns:
        A ``{"code", "message", "data"}`` envelope whose ``data`` holds the
        stored ``filename`` and its ``url`` under ``/upload/images/``.

    Raises:
        HTTPException: 400 when the upload has no usable filename; 413 when it
            exceeds ``_MAX_IMAGE_UPLOAD_BYTES``; 500 when saving fails.
    """
    # The client controls `filename`.  Keep only its last path component (for
    # both "/" and "\\" separators) so names such as "../../app.py" cannot
    # escape the upload directory.
    safe_name = os.path.basename((image.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的文件名",
        )

    os.makedirs("upload/images", exist_ok=True)  # create the directory on demand
    file_location = f"upload/images/{safe_name}"
    try:
        # Read one byte past the limit: enough to detect an oversized upload
        # without pulling the whole body into memory.
        content = await image.read(_MAX_IMAGE_UPLOAD_BYTES + 1)
        if len(content) > _MAX_IMAGE_UPLOAD_BYTES:
            raise HTTPException(status_code=413, detail="图片大小超过限制")
        with open(file_location, "wb") as buffer:
            buffer.write(content)
    except HTTPException:
        # Let the deliberate 413 through untouched.
        raise
    except Exception as e:
        print(f"图片上传失败: {str(e)}")  # log the detailed error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"图片上传失败: {str(e)}",
        ) from e
    return {
        "code": 200,
        "message": "图片上传成功",
        "data": {
            "filename": safe_name,
            "url": f"/upload/images/{safe_name}",
        },
    }
# Mount the static file directory so uploaded files are served under /upload.
# StaticFiles checks at construction time that the directory exists, so create
# it first; otherwise a fresh deployment fails at import, before upload_image
# has ever had the chance to create it.
os.makedirs("upload/images", exist_ok=True)
app.mount("/upload", StaticFiles(directory="upload"), name="upload")
def greet_json():
    """Return the fixed hello-world payload."""
    greeting = dict(Hello="World!")
    return greeting
async def wechat_token_check(
    signature: str = Query(..., description="微信加密签名"),
    timestamp: str = Query(..., description="时间戳"),
    nonce: str = Query(..., description="随机数"),
    echostr: str = Query(..., description="验证字符串"),
):
    """Validate a WeChat server-verification request, including timestamp freshness.

    The expected signature is the SHA-1 hex digest of ``WECHAT_TOKEN``,
    ``timestamp`` and ``nonce`` sorted lexicographically and concatenated.

    Args:
        signature: Signature sent with the request.
        timestamp: Unix timestamp (seconds) sent with the request.
        nonce: Random string sent with the request.
        echostr: Value to echo back when verification succeeds.

    Returns:
        ``echostr`` unchanged when the timestamp is fresh and the signature
        matches.

    Raises:
        HTTPException: 400 when ``timestamp`` is not an integer; 401 when it is
            more than 300 seconds from the current time or the signature does
            not match.
    """
    # NOTE(review): WeChat presumably expects the bare echostr as the response
    # body -- confirm the route is declared with a plain-text response class.

    # 1. Reject malformed or stale timestamps (5 minutes = 300 seconds).  A
    #    non-numeric value used to escape as an unhandled ValueError.
    try:
        request_time = int(timestamp)
    except ValueError as e:
        raise HTTPException(status_code=400, detail="时间戳格式错误") from e
    if abs(int(time.time()) - request_time) > 300:
        raise HTTPException(401, "时间戳过期")

    # 2. Sort and join the three values, then hash them with SHA-1.
    joined = "".join(sorted([WECHAT_TOKEN, timestamp, nonce]))
    hashcode = hashlib.sha1(joined.encode("utf-8")).hexdigest()

    # 3. Compare in constant time; bytes operands so a non-ASCII signature is
    #    simply rejected instead of raising TypeError.
    if not hmac.compare_digest(hashcode.encode("utf-8"), signature.encode("utf-8")):
        raise HTTPException(
            status_code=401,
            detail="签名验证失败",
            headers={"WWW-Authenticate": "Bearer"},
        )

    print("微信服务器验证成功: echostr", echostr)
    return echostr  # verification succeeded: echo the challenge back
# Attach api_router (declared with the /api prefix) to the application.
app.include_router(api_router)