# Source: ai-edu-api / app.py — commit bb052b9 ("update") by geqintan
# conda activate rzwl && uvicorn app:app --host 0.0.0.0 --port 7861 --reload
import hashlib  # SHA-1 for WeChat signature checks
import hmac  # constant-time signature comparison (hmac.compare_digest)
import json
import os
import sqlite3
import urllib.parse  # URL encoding
import uuid
from contextlib import closing  # actually closes sqlite3 connections
from typing import Optional, List, Dict

import httpx  # async HTTP client for WeChat API calls
from fastapi import FastAPI, APIRouter, HTTPException, status, File, UploadFile, Query  # File/UploadFile for image uploads
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse  # raw-text replies for WeChat echostr checks
from fastapi.staticfiles import StaticFiles
from passlib.context import CryptContext
from pydantic import BaseModel, Field
# NOTE: FastAPI has no `max_upload_size` option. Unknown keyword arguments are
# only stored in `app.extra`; they are NOT enforced. The kwarg is kept so that
# `app.extra["max_upload_size"]` stays available, but any upload size limit
# must be enforced inside the upload handler itself.
app = FastAPI(max_upload_size=10 * 1024 * 1024)
# Path of the SQLite database file; a relative path resolves against the
# process working directory.
DATABASE_URL = "ai_edu.db"
# Password-hashing context used by verify_password / get_password_hash (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
    """Return True when *plain_password* matches the stored *hashed_password*."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Hash *password* with the module's bcrypt CryptContext for storage."""
    hashed = pwd_context.hash(password)
    return hashed
def init_db():
    """Create the application's SQLite tables if missing and apply migrations.

    Safe to call repeatedly:
    - every CREATE uses ``IF NOT EXISTS``;
    - ``courses.page_info`` is added only when ``PRAGMA table_info`` shows it
      is absent (presumably for databases created before that column existed).
    """
    # sqlite3's `with conn:` only commits/rolls back the transaction; it does
    # not close the connection. closing() is what actually releases it.
    with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                user_id TEXT PRIMARY KEY,
                username TEXT UNIQUE NOT NULL,
                password TEXT NOT NULL
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS enrollments (
                enrollment_id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                phone TEXT NOT NULL,
                company TEXT,
                position TEXT,
                email TEXT
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS enterprise_orders (
                order_id TEXT PRIMARY KEY,
                company_name TEXT NOT NULL,
                contact_person TEXT NOT NULL,
                contact_phone TEXT NOT NULL,
                course_name TEXT NOT NULL,
                quantity INTEGER NOT NULL,
                order_date TEXT NOT NULL
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS courses (
                course_id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL,
                description TEXT,
                price REAL NOT NULL,
                duration_hours INTEGER,
                level TEXT
            )
        """)
        # Migration: add page_info when missing. PRAGMA table_info rows are
        # (cid, name, type, notnull, dflt_value, pk), so index 1 is the column name.
        existing_columns = {
            row[1] for row in cursor.execute("PRAGMA table_info(courses);").fetchall()
        }
        if "page_info" not in existing_columns:
            cursor.execute("ALTER TABLE courses ADD COLUMN page_info TEXT;")
        conn.commit()
# Build (or migrate) the database schema when the application starts.
@app.on_event("startup")
async def startup_event():
    """Run init_db() once at ASGI startup so every table exists before requests arrive."""
    # NOTE(review): `on_event` is deprecated in recent FastAPI releases in
    # favour of lifespan handlers; consider migrating.
    init_db()
# CORS configuration. Allowed origins come from the comma-separated
# CORS_ALLOW_ORIGINS environment variable. When it is unset or empty, the
# previous development default ("*", every origin) is kept.
# NOTE(security): with allow_origins=["*"] plus allow_credentials=True,
# Starlette reflects the caller's Origin for credentialed requests, so any
# site can make authenticated cross-origin calls. Set CORS_ALLOW_ORIGINS to
# the real front-end domains in production.
origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],  # every method, including OPTIONS preflight
    allow_headers=["*"],  # every request header
)
# All JSON endpoints below are registered on this router under the /api prefix.
api_router = APIRouter(prefix="/api")
class RegisterRequest(BaseModel):
    # JSON body for POST /api/auth/register.
    username: str
    password: str  # plaintext as sent by the client; hashed with bcrypt before storage
class LoginRequest(BaseModel):
    # JSON body for POST /api/auth/login.
    username: str
    password: str  # plaintext; checked against the stored bcrypt hash
class EnrollmentIndividualRequest(BaseModel):
    # JSON body for POST /api/enrollment/individual; stored in the enrollments table.
    # NOTE(review): the enrollments table allows NULL for company/position/email,
    # but this model requires all three fields — confirm that is intended.
    name: str
    phone: str
    company: str
    position: str
    email: str
# WeChat OAuth / server-verification settings, read from the environment.
# The fallback values are placeholders and must be overridden in deployment.
# NOTE(security): while WECHAT_TOKEN keeps its placeholder value, anyone who
# knows this source can compute valid signatures for the verification endpoints.
WECHAT_APP_ID = os.environ.get("WECHAT_APP_ID", "YOUR_WECHAT_APP_ID")
WECHAT_APP_SECRET = os.environ.get("WECHAT_APP_SECRET", "YOUR_WECHAT_APP_SECRET")
# OAuth redirect target; the default points at this service's own
# /api/auth/wechat/callback route.
WECHAT_REDIRECT_URI = os.environ.get(
    "WECHAT_REDIRECT_URI", "http://localhost:7861/api/auth/wechat/callback"
)
# Token configured in the WeChat admin console, used to verify server signatures.
WECHAT_TOKEN = os.environ.get("WECHAT_TOKEN", "YOUR_WECHAT_VERIFICATION_TOKEN")
class Course(BaseModel):
    # Request body for POST /api/course/save. course_id None -> insert a new
    # course; otherwise update the existing row with that id.
    course_id: Optional[int] = None
    title: str
    description: Optional[str] = None
    price: float = 0
    duration_hours: Optional[int] = 0
    level: Optional[str] = None
    # Accepts a list of JSON objects (not a pre-serialized string); save_course
    # json.dumps it into the courses.page_info TEXT column.
    page_info: Optional[List[Dict]] = None
@api_router.post("/auth/register")
async def register_user(request: RegisterRequest):
    """Register a new user; the password is stored as a bcrypt hash.

    Returns the standard envelope with the generated ``user_id``.

    Raises:
        HTTPException 400: empty username/password, or the username already
            exists (UNIQUE constraint on ``users.username``).
        HTTPException 500: any other failure.
    """
    if not request.username.strip() or not request.password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户名和密码不能为空"
        )
    user_id = str(uuid.uuid4())
    try:
        # Hash before opening the DB so the connection is not held during bcrypt.
        # NOTE: bcrypt is CPU-bound and runs on the event-loop thread here.
        hashed_password = get_password_hash(request.password)
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            conn.execute(
                "INSERT INTO users (user_id, username, password) VALUES (?, ?, ?)",
                (user_id, request.username, hashed_password)
            )
            conn.commit()
    except sqlite3.IntegrityError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户名已存在"
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"注册失败: {str(e)}"
        ) from e
    return {
        "code": 200,
        "message": "注册成功",
        "data": {
            "user_id": user_id,
            "username": request.username
        }
    }
@api_router.post("/auth/login")
async def login_user(request: LoginRequest):
    """Check a username/password pair against the ``users`` table.

    Failures are reported inside the JSON envelope (``code`` 401 or 500). The
    handler itself never raises, so it always answers with HTTP 200.
    """
    try:
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            user = conn.execute(
                "SELECT user_id, username, password FROM users WHERE username = ?",
                (request.username,)
            ).fetchone()

        # NOTE: bcrypt verification is CPU-bound and runs on the event-loop thread.
        if user is None:
            # Spend about as long as a real verification, so an unknown username
            # cannot be distinguished from a wrong password by response time.
            pwd_context.dummy_verify()
        else:
            user_id, username, hashed_password = user
            if verify_password(request.password, hashed_password):
                return {
                    "code": 200,
                    "message": "登录成功",
                    "data": {
                        "user_id": user_id,
                        "username": username
                    }
                }

        # Same response for "no such user" and "wrong password".
        return {
            "code": 401,
            "message": "用户名或密码错误",
            "data": None
        }
    except Exception as e:
        # Catch-all: report a 500 inside the envelope instead of raising.
        print(f"登录失败未知错误: {type(e).__name__}: {e}")
        return {
            "code": 500,
            "message": f"登录失败: {str(e)}",
            "data": None
        }
@api_router.get("/auth/wechat/qrcode")
async def get_wechat_qrcode():
    """
    Build the WeChat QR-code login URL plus a random ``state`` value.

    The ``state`` is returned to the front end too.
    TODO(security): it is not persisted server-side, so the callback
    cannot verify it yet (no CSRF protection).
    """
    csrf_state = str(uuid.uuid4())
    query_string = "&".join([
        f"appid={WECHAT_APP_ID}",
        f"redirect_uri={urllib.parse.quote_plus(WECHAT_REDIRECT_URI)}",
        "response_type=code",
        "scope=snsapi_login",
        f"state={csrf_state}",
    ])
    qrcode_url = f"https://open.weixin.qq.com/connect/qrconnect?{query_string}#wechat_redirect"
    payload = {"qrcode_url": qrcode_url, "state": csrf_state}
    return {"code": 200, "message": "获取微信二维码URL成功", "data": payload}
@api_router.get("/auth/wechat/callback")
async def wechat_callback(code: str, state: str):
    """
    WeChat OAuth callback: exchange the authorization ``code`` for an access token.

    Returns the standard envelope with ``access_token``, ``openid`` and
    ``unionid`` (the latter may be None).

    Raises:
        HTTPException 400: WeChat answered with an ``errcode``.
        HTTPException 502: the token request failed or returned non-JSON.
    """
    # TODO(security): `state` is not compared with the value issued by
    # /auth/wechat/qrcode, so this callback has no CSRF protection yet.
    print(f"Received WeChat callback with code: {code} and state: {state}")

    # Pass query parameters via `params=` so httpx URL-encodes them. `code`
    # comes straight from the caller and must not be able to inject extra
    # parameters into the token request.
    token_params = {
        "appid": WECHAT_APP_ID,
        "secret": WECHAT_APP_SECRET,
        "code": code,
        "grant_type": "authorization_code",
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.weixin.qq.com/sns/oauth2/access_token",
                params=token_params,
            )
        token_data = response.json()
    except (httpx.HTTPError, ValueError) as e:
        # Network failure, or a body that is not valid JSON.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"微信授权失败: {str(e)}"
        ) from e

    if "errcode" in token_data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"微信授权失败: {token_data.get('errmsg', '未知错误')}"
        )

    # unionid may be absent depending on the app's WeChat configuration.
    # Next step for a real login: use access_token + openid to create or look
    # up the local user, e.g. via https://api.weixin.qq.com/sns/userinfo.
    return {
        "code": 200,
        "message": "微信登录回调成功",
        "data": {
            "access_token": token_data.get("access_token"),
            "openid": token_data.get("openid"),
            "unionid": token_data.get("unionid"),
        }
    }
@api_router.get("/wechat/verify")
async def wechat_verify(signature: str, timestamp: str, nonce: str, echostr: str):
    """
    WeChat server-configuration check.

    The expected signature is the SHA-1 hex digest of the lexicographically
    sorted, concatenated (WECHAT_TOKEN, timestamp, nonce). When it matches
    ``signature``, ``echostr`` is echoed back as plain text.

    Raises:
        HTTPException 400: signature mismatch.
    """
    joined = "".join(sorted([WECHAT_TOKEN, timestamp, nonce]))
    expected = hashlib.sha1(joined.encode("utf-8")).hexdigest()
    # Constant-time comparison. Both sides are compared as bytes because
    # compare_digest rejects non-ASCII str input with TypeError.
    if not hmac.compare_digest(expected.encode("utf-8"), signature.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="微信验证失败"
        )
    # WeChat expects the bare echostr. Returning a plain str would make FastAPI
    # JSON-encode it (wrapped in quotes), and the verification would fail.
    return PlainTextResponse(echostr)
@api_router.post("/enrollment/individual")
async def enroll_individual(request: EnrollmentIndividualRequest):
    """Insert an individual enrollment and echo it back with its new UUID.

    Raises:
        HTTPException 500: any database error.
    """
    enrollment_id = str(uuid.uuid4())
    try:
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            conn.execute(
                "INSERT INTO enrollments (enrollment_id, name, phone, company, position, email) VALUES (?, ?, ?, ?, ?, ?)",
                (enrollment_id, request.name, request.phone, request.company, request.position, request.email)
            )
            conn.commit()
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"报名失败: {str(e)}"
        ) from e
    return {
        "code": 200,
        "message": "个人报名成功",
        "data": {
            "enrollment_id": enrollment_id,
            "name": request.name,
            "phone": request.phone,
            "company": request.company,
            "position": request.position,
            "email": request.email
        }
    }
@api_router.get("/orders/enterprise")
async def get_enterprise_orders():
    """Return every row of ``enterprise_orders`` as a dict keyed by column name.

    Raises:
        HTTPException 500: any database error.
    """
    try:
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            conn.row_factory = sqlite3.Row  # makes dict(row) map column name -> value
            rows = conn.execute("SELECT * FROM enterprise_orders").fetchall()
            orders_list = [dict(row) for row in rows]
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取企业订单失败: {str(e)}"
        ) from e
    return {
        "code": 200,
        "message": "获取企业订单成功",
        "data": orders_list
    }
@api_router.get("/orders/individual")
async def get_individual_orders():
    """Return every individual enrollment as a dict keyed by column name.

    Raises:
        HTTPException 500: any database error.
    """
    try:
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            conn.row_factory = sqlite3.Row  # makes dict(row) map column name -> value
            rows = conn.execute(
                "SELECT enrollment_id, name, phone, company, position, email FROM enrollments"
            ).fetchall()
            enrollments_list = [dict(row) for row in rows]
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取个人报名订单失败: {str(e)}"
        ) from e
    return {
        "code": 200,
        "message": "获取个人报名订单成功",
        "data": enrollments_list
    }
@api_router.post("/course/save")
async def save_course(course: Course):
    """Insert a new course, or update an existing one when ``course_id`` is set.

    ``page_info`` is serialized with ``json.dumps`` into a TEXT column. The
    response ``data`` echoes the fields the client sent (``exclude_unset``)
    plus the effective ``course_id``.

    Raises:
        HTTPException 404: ``course_id`` was given but matches no row.
        HTTPException 500: any other database/serialization error.
    """
    try:
        # The page_info column stores a JSON string (or NULL).
        page_info_json = json.dumps(course.page_info) if course.page_info is not None else None
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            cursor = conn.cursor()
            if course.course_id is not None:
                # Update the existing course.
                cursor.execute(
                    "UPDATE courses SET title = ?, description = ?, price = ?, duration_hours = ?, level = ?, page_info = ? WHERE course_id = ?",
                    (course.title, course.description, course.price, course.duration_hours, course.level, page_info_json, course.course_id)
                )
                conn.commit()
                if cursor.rowcount == 0:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="未找到该课程进行更新"
                    )
                course_id = course.course_id
                message = "课程更新成功"
            else:
                # Insert a new course.
                cursor.execute(
                    "INSERT INTO courses (title, description, price, duration_hours, level, page_info) VALUES (?, ?, ?, ?, ?, ?)",
                    (course.title, course.description, course.price, course.duration_hours, course.level, page_info_json)
                )
                conn.commit()
                course_id = cursor.lastrowid
                message = "课程添加成功"

        data = {"course_id": course_id, **course.dict(exclude_unset=True)}
        # A client that explicitly sends "course_id": null for a new course
        # would otherwise get that null back instead of the generated id.
        data["course_id"] = course_id
        return {
            "code": 200,
            "message": message,
            "data": data
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"保存课程失败: {str(e)}"
        ) from e
@api_router.get("/course/details/{course_id}")
async def get_course_details(course_id: int):
    """Return one course by id, with ``page_info`` decoded from its stored JSON.

    Raises:
        HTTPException 404: no course has that id.
        HTTPException 500: database error or undecodable ``page_info``.
    """
    try:
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            conn.row_factory = sqlite3.Row  # makes dict(row) map column name -> value
            row = conn.execute(
                "SELECT * FROM courses WHERE course_id = ?", (course_id,)
            ).fetchone()
            course_dict = dict(row) if row is not None else None

        if course_dict is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="未找到该课程"
            )
        # page_info is stored as a JSON string (or NULL/empty); decode it for the client.
        if course_dict['page_info']:
            course_dict['page_info'] = json.loads(course_dict['page_info'])
        return {
            "code": 200,
            "message": "获取课程详情成功",
            "data": course_dict
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取课程详情失败: {str(e)}"
        ) from e
@api_router.delete("/enrollment/individual/{enrollment_id}")
async def delete_individual_enrollment(enrollment_id: str):
    """Delete one individual enrollment by its id.

    Raises:
        HTTPException 404: no enrollment has that id.
        HTTPException 500: any other database error.
    """
    try:
        # `with conn` only scopes the transaction; closing() releases the connection.
        with closing(sqlite3.connect(DATABASE_URL)) as conn, conn:
            cursor = conn.execute(
                "DELETE FROM enrollments WHERE enrollment_id = ?", (enrollment_id,)
            )
            conn.commit()
            deleted = cursor.rowcount
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除报名数据失败: {str(e)}"
        ) from e
    if deleted == 0:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="未找到该报名数据"
        )
    return {
        "code": 200,
        "message": "报名数据删除成功",
        "data": {"enrollment_id": enrollment_id}
    }
# Upload size limit for images (10 MB). FastAPI itself enforces no upload
# limit, so the handler checks it explicitly.
_MAX_IMAGE_UPLOAD_BYTES = 10 * 1024 * 1024


@api_router.post("/upload/images")
async def upload_image(image: UploadFile = File(...)):  # multipart field name is "image"
    """Store an uploaded image under ``upload/images`` and return its public URL.

    Only the final path component of the client-supplied filename is used, so
    names such as ``../../app.py`` cannot write outside the upload directory.
    An existing file with the same name is overwritten.

    Raises:
        HTTPException 400: missing or invalid filename.
        HTTPException 413: file larger than ``_MAX_IMAGE_UPLOAD_BYTES``.
        HTTPException 500: any failure while reading or writing the file.
    """
    # Normalise Windows separators too, then keep only the basename.
    safe_name = os.path.basename((image.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的文件名"
        )

    upload_dir = os.path.join("upload", "images")
    os.makedirs(upload_dir, exist_ok=True)  # create the directory on demand
    file_location = os.path.join(upload_dir, safe_name)
    try:
        # Read at most one byte past the limit: enough to detect oversized
        # uploads without buffering an arbitrarily large body.
        content = await image.read(_MAX_IMAGE_UPLOAD_BYTES + 1)
        if len(content) > _MAX_IMAGE_UPLOAD_BYTES:
            raise HTTPException(
                status_code=413,
                detail=f"图片大小不能超过 {_MAX_IMAGE_UPLOAD_BYTES // (1024 * 1024)}MB"
            )
        with open(file_location, "wb") as buffer:
            buffer.write(content)
    except HTTPException:
        raise
    except Exception as e:
        print(f"图片上传失败: {str(e)}")  # log the underlying error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"图片上传失败: {str(e)}"
        ) from e
    return {
        "code": 200,
        "message": "图片上传成功",
        "data": {
            "filename": safe_name,
            # Percent-encode so names with spaces or non-ASCII characters form a valid URL path.
            "url": f"/upload/images/{urllib.parse.quote(safe_name)}"
        }
    }
# Serve uploaded files (e.g. /upload/images/<name>) from the local "upload"
# directory. StaticFiles raises RuntimeError at construction time when the
# directory does not exist, so create it before mounting instead of relying
# on the first upload request to do so.
os.makedirs(os.path.join("upload", "images"), exist_ok=True)
app.mount("/upload", StaticFiles(directory="upload"), name="upload")
@app.get("/")
def greet_json():
    """Root endpoint: always answers with a fixed JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting
@app.get("/token")
async def wechat_token_check(
    signature: str = Query(..., description="微信加密签名"),
    timestamp: str = Query(..., description="时间戳"),
    nonce: str = Query(..., description="随机数"),
    echostr: str = Query(..., description="验证字符串"),
):
    """WeChat server verification (root-level variant of /api/wechat/verify).

    Rejects requests whose ``timestamp`` is non-numeric or more than 300 s
    away from the server clock. Otherwise it checks the SHA-1 signature over
    the sorted (WECHAT_TOKEN, timestamp, nonce) and echoes ``echostr`` back as
    plain text on success.

    Raises:
        HTTPException 401: invalid/expired timestamp or signature mismatch.
    """
    import time

    # 1. Sort and concatenate the parameters, then SHA-1 them.
    joined = "".join(sorted([WECHAT_TOKEN, timestamp, nonce]))
    hashcode = hashlib.sha1(joined.encode("utf-8")).hexdigest()

    # 2. Replay window: reject timestamps more than 5 minutes (300 s) off.
    #    A non-numeric timestamp used to crash with ValueError (HTTP 500).
    try:
        request_time = int(timestamp)
    except ValueError:
        raise HTTPException(401, "时间戳无效") from None
    if abs(int(time.time()) - request_time) > 300:
        raise HTTPException(401, "时间戳过期")

    # 3. Constant-time signature check (bytes, so non-ASCII input cannot raise).
    if hmac.compare_digest(hashcode.encode("utf-8"), signature.encode("utf-8")):
        print("微信服务器验证成功: echostr", echostr)
        # WeChat expects the bare echostr; a plain str return would be JSON-quoted.
        return PlainTextResponse(echostr)
    raise HTTPException(
        status_code=401,
        detail="签名验证失败",
        headers={"WWW-Authenticate": "Bearer"}
    )
# Register every /api route. Kept last so that all @api_router endpoints above
# exist when FastAPI copies the router's routes into the app.
app.include_router(router=api_router)