# ComfyUI-Ranking-API / router_posts.py
# ZHIWEI666's picture
# Upload 4 files
# c70d6e7 verified
# 云端Space代码/router_posts.py
# ==========================================
# 💬 讨论区API路由(小红书风格图文社区)
# ==========================================
# 功能:帖子发布、列表、详情、互动(点赞/收藏/评论/打赏)
# ==========================================
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from models import PostCreate, PostUpdate, RatingRequest
import 数据库连接 as db
from 安全认证 import require_auth
from db_utils import record_view, sort_cache
from database_sql import get_db
from models_sql import Wallet, Transaction
from notifications import add_notification
import time
import uuid
import hashlib
import datetime
import logging
# Router that every discussion-board (post) endpoint in this module registers on
router = APIRouter()
# ==========================================
# 📝 帖子CRUD接口
# ==========================================
def _tip_total(post):
    """Return the sum of all tip amounts recorded on a post's ``tip_board``."""
    return sum(t.get("amount", 0) for t in post.get("tip_board", []))


# Sort-key functions for GET /api/posts, keyed by the ``sort`` query value.
# Every mode sorts descending; unknown modes fall back to "latest".
_POST_SORT_KEYS = {
    "latest": lambda x: x.get("created_at", 0),
    "likes": lambda x: x.get("likes", 0),
    "favorites": lambda x: x.get("favorites", 0),
    "views": lambda x: x.get("views", 0),
    "daily_views": lambda x: x.get("daily_views", 0),
    "tips": _tip_total,
    "rating": lambda x: (x.get("rating_avg", 0), x.get("rating_count", 0)),
}


@router.get("/api/posts")
async def get_posts(page: int = 1, limit: int = 20, sort: str = "latest"):
    """
    Return one page of posts, enriched with author display info.

    Sort modes (all descending):
    - sort=latest: creation time (default; also used for unknown values)
    - sort=likes: like count
    - sort=favorites: favorite count
    - sort=views: total views
    - sort=daily_views: daily views
    - sort=tips: total tipped amount
    - sort=rating: average rating, then rating count

    ``page`` is clamped to >= 1 and ``limit`` to >= 0. Negative values used
    to yield negative slice bounds and return items from the end of the list.
    ``viewed_by`` is stripped from every post; ``liked_by`` and
    ``favorited_by`` are kept.
    """
    posts_db = db.load_data("posts.json", default_data=[])
    # users.json is already a {account: user_info} mapping
    users_db = db.load_data("users.json", default_data={})

    # Normalise the mode before it becomes a cache key, so arbitrary query
    # strings cannot create an unbounded number of sort-cache entries.
    sort_mode = sort if sort in _POST_SORT_KEYS else "latest"
    sort_key = _POST_SORT_KEYS[sort_mode]

    def sort_fn(data):
        # In-place sort, same contract as before for sort_cache.get_sorted
        data.sort(key=sort_key, reverse=True)

    # 🗂️ Use the sort cache to avoid re-sorting on every request
    sorted_posts = sort_cache.get_sorted(f"posts:{sort_mode}", posts_db, sort_fn)

    page = max(page, 1)
    limit = max(limit, 0)
    start = (page - 1) * limit
    paged_posts = sorted_posts[start:start + limit]

    result = []
    for post in paged_posts:
        author_info = users_db.get(post.get("author"), {})
        post_data = {
            **post,
            "author_name": author_info.get("name", post.get("author")),
            "author_avatar": author_info.get("avatarDataUrl", "")
        }
        # Only the per-user view list is treated as sensitive here
        post_data.pop("viewed_by", None)
        result.append(post_data)

    return {
        "status": "success",
        "data": result,
        "total": len(posts_db),
        "page": page,
        "limit": limit
    }
@router.get("/api/my-posts")
async def get_my_posts(current_user: str = Depends(require_auth)):
    """
    Return every post written by the authenticated user, newest first.

    Each post is shallow-copied and its ``viewed_by`` list removed;
    ``liked_by`` and ``favorited_by`` are left in place.
    """
    all_posts = db.load_data("posts.json", default_data=[])

    mine = sorted(
        (p for p in all_posts if p.get("author") == current_user),
        key=lambda p: p.get("created_at", 0),
        reverse=True,
    )

    def _strip_private(entry):
        cleaned = dict(entry)
        cleaned.pop("viewed_by", None)
        return cleaned

    return {
        "status": "success",
        "data": [_strip_private(p) for p in mine]
    }
@router.get("/api/posts/{post_id}")
async def get_post_detail(post_id: str):
    """
    Return a single post by id, enriched with author display info.

    Raises HTTPException 404 when no post has this id. The ``viewed_by``
    list is removed from the response.
    """
    all_posts = db.load_data("posts.json", default_data=[])
    # users.json is already a {account: user_info} mapping
    users = db.load_data("users.json", default_data={})

    match = next((p for p in all_posts if p["id"] == post_id), None)
    if match is None:
        raise HTTPException(status_code=404, detail="帖子不存在")

    author = match.get("author")
    profile = users.get(author, {})
    detail = {
        **match,
        "author_name": profile.get("name", author),
        "author_avatar": profile.get("avatarDataUrl", "")
    }
    detail.pop("viewed_by", None)
    return {"status": "success", "data": detail}
@router.post("/api/posts")
async def create_post(post: PostCreate, current_user: str = Depends(require_auth)):
    """
    Publish a new post authored by the current user.

    - At most 9 images are kept; video posts always store an empty image list.
    - ``post_type`` defaults to "image" when the client omits it or sends null.
      The old ``hasattr`` check is always true on the model, so a missing value
      was stored as None.
    - The new post is prepended inside ``db.atomic_update``. Concurrent
      like/rating/tip writes to posts.json therefore cannot be overwritten by a
      stale copy, as they could with the previous load/save pair.

    Returns ``{"status": "success", "data": <new post dict>}``.
    """
    now = int(time.time())
    post_type = getattr(post, "post_type", None) or "image"
    video_url = getattr(post, "video_url", None)
    # Video posts carry their media in video_url; images are always empty
    images = [] if post_type == "video" else (post.images or [])[:9]

    new_post = {
        "id": f"post_{now}_{uuid.uuid4().hex[:6]}",
        "title": post.title,
        "content": post.content,
        "cover_image": post.cover_image,
        "images": images,
        "author": current_user,
        "created_at": now,
        "is_original": post.is_original if post.is_original is not None else False,  # 🎨 original-work flag
        "post_type": post_type,
        "video_url": video_url,
        # Interaction data
        "likes": 0,
        "favorites": 0,
        "comments": 0,
        "liked_by": [],
        "favorited_by": [],
        "tip_board": [],  # tip leaderboard
        "views": 0,
        "daily_views": 0,
        "viewed_by": [],
        "daily_views_date": "",
        "rating_avg": 0.0,
        "rating_count": 0,
        "rating_dist": {"1": 0, "2": 0, "3": 0, "4": 0, "5": 0},
        "rated_by": {}
    }

    def updater(data):
        # Newest post goes first, matching the previous insert(0, ...)
        data.insert(0, new_post)
        return True

    db.atomic_update("posts.json", updater, default_data=[])
    # 🗂️ Invalidate every cached post ordering
    sort_cache.invalidate("posts:")
    return {"status": "success", "data": new_post}
@router.put("/api/posts/{post_id}")
async def update_post(post_id: str, update_data: PostUpdate, current_user: str = Depends(require_auth)):
    """
    Partially update a post (author only) inside ``db.atomic_update``.

    Only fields that are not None on ``update_data`` are applied, and
    ``images`` is truncated to 9 entries. Whenever the resulting post is a
    video post, its image list is cleared. This applies whether the post just
    switched to video or already was one, keeping the "video posts have no
    images" invariant that post creation establishes.

    Raises:
        HTTPException 403: the caller is not the author.
        HTTPException 404: no post with this id.
    """
    result_holder = {}

    def updater(data):
        for post in data:
            if post["id"] != post_id:
                continue
            if post.get("author") != current_user:
                result_holder["error"] = "forbidden"
                return False
            # Apply only the fields the client actually sent
            if update_data.title is not None:
                post["title"] = update_data.title
            if update_data.content is not None:
                post["content"] = update_data.content
            if update_data.cover_image is not None:
                post["cover_image"] = update_data.cover_image
            if update_data.images is not None:
                post["images"] = update_data.images[:9]
            if update_data.is_original is not None:
                post["is_original"] = update_data.is_original  # 🎨 original-work flag
            if update_data.post_type is not None:
                post["post_type"] = update_data.post_type
            if update_data.video_url is not None:
                post["video_url"] = update_data.video_url
            # Video posts never carry images, including edits that only touch images
            if post.get("post_type") == "video":
                post["images"] = []
            return True
        result_holder["error"] = "not_found"
        return False

    db.atomic_update("posts.json", updater, default_data=[])
    if result_holder.get("error") == "forbidden":
        raise HTTPException(status_code=403, detail="无权修改他人帖子")
    if result_holder.get("error") == "not_found":
        raise HTTPException(status_code=404, detail="帖子不存在")
    # 🗂️ Invalidate every cached post ordering
    sort_cache.invalidate("posts:")
    return {"status": "success"}
@router.delete("/api/posts/{post_id}")
async def delete_post(post_id: str, current_user: str = Depends(require_auth)):
    """
    Delete a post (author only).

    The removal runs inside ``db.atomic_update``. The previous load/pop/save
    sequence could overwrite likes, ratings or tips that other endpoints wrote
    to posts.json between the load and the save.

    Raises:
        HTTPException 403: the post belongs to someone else.
        HTTPException 404: no post with this id.
    """
    outcome = {}

    def updater(data):
        for index, post in enumerate(data):
            if post["id"] == post_id:
                if post.get("author") != current_user:
                    outcome["error"] = "forbidden"
                    return False
                data.pop(index)
                return True
        outcome["error"] = "not_found"
        return False

    db.atomic_update("posts.json", updater, default_data=[])
    if outcome.get("error") == "forbidden":
        raise HTTPException(status_code=403, detail="无权删除他人帖子")
    if outcome.get("error") == "not_found":
        raise HTTPException(status_code=404, detail="帖子不存在")
    # 🗂️ Invalidate every cached post ordering
    sort_cache.invalidate("posts:")
    return {"status": "success"}
# ==========================================
# ❤️ 互动接口(评分/点赞/收藏)
# ==========================================
@router.post("/api/posts/{post_id}/rating")
async def rate_post(post_id: str, request: RatingRequest, current_user: str = Depends(require_auth)):
    """
    Rate a post with 1-5 stars inside ``db.atomic_update``.

    Re-rating replaces the caller's previous score: the old score is removed
    from the distribution and ``rating_count`` is left unchanged. Missing rating
    fields on older posts are initialised on first use.

    Returns the new ``rating_avg``, ``rating_count``, ``rating_dist`` and the
    caller's ``user_score``.

    Raises:
        HTTPException 400: score outside 1-5, or the caller authored the post.
        HTTPException 404: no post with this id.
    """
    score = request.score
    if score > 5 or score < 1:
        raise HTTPException(status_code=400, detail="评分必须在1-5之间")

    outcome = {"value": None}

    def updater(data):
        post = next((p for p in data if p["id"] == post_id), None)
        if post is None:
            return
        # Authors may not rate their own posts
        if post.get("author") == current_user:
            outcome["value"] = {"error": "self_rating"}
            return

        post.setdefault("rating_avg", 0.0)
        post.setdefault("rating_count", 0)
        dist = post.setdefault("rating_dist", {"1": 0, "2": 0, "3": 0, "4": 0, "5": 0})
        voters = post.setdefault("rated_by", {})

        previous = voters[current_user]["score"] if current_user in voters else None
        if previous is None:
            # First rating from this user: one more rater
            post["rating_count"] = post.get("rating_count", 0) + 1
        else:
            # Re-rating: retract the old score from the distribution
            old_key = str(previous)
            dist[old_key] = max(0, dist.get(old_key, 0) - 1)
        new_key = str(score)
        dist[new_key] = dist.get(new_key, 0) + 1
        voters[current_user] = {"score": score, "time": int(time.time())}

        raters = post["rating_count"]
        weighted_sum = sum(int(star) * hits for star, hits in dist.items())
        post["rating_avg"] = round(weighted_sum / raters, 2) if raters > 0 else 0.0

        outcome["value"] = {
            "status": "success",
            "rating_avg": post["rating_avg"],
            "rating_count": post["rating_count"],
            "rating_dist": post["rating_dist"],
            "user_score": score
        }

    db.atomic_update("posts.json", updater, default_data=[])
    result = outcome["value"]
    if result is None:
        raise HTTPException(status_code=404, detail="帖子不存在")
    if result.get("error") == "self_rating":
        raise HTTPException(status_code=400, detail="不能给自己发布的帖子评分")
    # 🗂️ A rating change can reorder the "rating" sort
    sort_cache.invalidate("posts:")
    return result
@router.post("/api/posts/{post_id}/like")
async def toggle_like(post_id: str, current_user: str = Depends(require_auth)):
    """
    Toggle the caller's like on a post inside ``db.atomic_update``.

    Returns ``{"status": "success", "action": "liked" | "unliked", "likes": n}``.
    The like counter is never decremented below zero.

    Raises:
        HTTPException 404: no post with this id.
    """
    outcome = {}

    def updater(data):
        post = next((p for p in data if p["id"] == post_id), None)
        if post is None:
            return
        likers = post.get("liked_by", [])
        was_liked = current_user in likers
        if was_liked:
            likers.remove(current_user)
            post["likes"] = max(0, post.get("likes", 0) - 1)
        else:
            likers.append(current_user)
            post["likes"] = post.get("likes", 0) + 1
        post["liked_by"] = likers
        outcome["result"] = {
            "status": "success",
            "action": "unliked" if was_liked else "liked",
            "likes": post["likes"]
        }

    db.atomic_update("posts.json", updater, default_data=[])
    result = outcome.get("result")
    if result is None:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # 🗂️ Like counts feed the "likes" sort
    sort_cache.invalidate("posts:")
    return result
@router.post("/api/posts/{post_id}/favorite")
async def toggle_favorite(post_id: str, current_user: str = Depends(require_auth)):
    """
    Toggle the caller's favorite on a post inside ``db.atomic_update``.

    Returns ``{"status": "success", "action": "favorited" | "unfavorited",
    "favorites": n}``. The counter is never decremented below zero.

    Raises:
        HTTPException 404: no post with this id.
    """
    outcome = {}

    def updater(data):
        post = next((p for p in data if p["id"] == post_id), None)
        if post is None:
            return
        fans = post.get("favorited_by", [])
        was_favorited = current_user in fans
        if was_favorited:
            fans.remove(current_user)
            post["favorites"] = max(0, post.get("favorites", 0) - 1)
        else:
            fans.append(current_user)
            post["favorites"] = post.get("favorites", 0) + 1
        post["favorited_by"] = fans
        outcome["result"] = {
            "status": "success",
            "action": "unfavorited" if was_favorited else "favorited",
            "favorites": post["favorites"]
        }

    db.atomic_update("posts.json", updater, default_data=[])
    result = outcome.get("result")
    if result is None:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # 🗂️ Favorite counts feed the "favorites" sort
    sort_cache.invalidate("posts:")
    return result
# ==========================================
# 🎁 打赏接口
# ==========================================
# 📝 审计日志
logger = logging.getLogger("ComfyUI-Ranking.Posts")


def calculate_tx_hash(tx_id, account, tx_type, amount, prev_hash):
    """
    Return the SHA-256 hex digest that links a transaction to its predecessor.

    The fields are stringified and concatenated in the order tx_id, account,
    tx_type, amount, prev_hash, with no separator, then UTF-8 encoded.

    NOTE: separator-free concatenation is ambiguous (("ab", "c") hashes like
    ("a", "bc")). The format must nevertheless stay byte-for-byte stable,
    because stored ``tx_hash`` values are reused as the next row's ``prev_hash``.
    """
    fields = (tx_id, account, tx_type, amount, prev_hash)
    payload = "".join(str(field) for field in fields)
    digest = hashlib.sha256(payload.encode("utf-8"))
    return digest.hexdigest()
def _add_tip_to_board(post_id, tipper, amount, is_anon):
    """Accumulate a committed tip into the post's ``tip_board`` atomically.

    Returns False when the post no longer exists (e.g. it was deleted after
    the transfer was committed), True otherwise.
    """
    found = [False]

    def updater(data):
        for post in data:
            if post["id"] == post_id:
                tip_board = post.get("tip_board", [])
                existing = next((t for t in tip_board if t["account"] == tipper), None)
                if existing:
                    existing["amount"] += amount
                else:
                    tip_board.append({"account": tipper, "amount": amount, "is_anon": is_anon})
                # Keep the board ordered by cumulative amount, largest first
                tip_board.sort(key=lambda x: x["amount"], reverse=True)
                post["tip_board"] = tip_board
                found[0] = True
                return

    db.atomic_update("posts.json", updater, default_data=[])
    return found[0]


@router.post("/api/posts/{post_id}/tip")
async def tip_post(post_id: str, amount: int, is_anon: bool = False, current_user: str = Depends(require_auth), db_session: Session = Depends(get_db)):
    """
    Tip a post's author ``amount`` points through the SQL Wallet system.

    Order of operations (money first, display second):
      1. Validate against a read-only snapshot of posts.json: the post exists
         and is not the caller's own.
      2. Treat an identical TIP_OUT from the last 5 seconds as a duplicate.
      3. Lock both wallets (``with_for_update``), check the balance, move the
         funds, append TIP_OUT/TIP_IN ledger rows, commit.
      4. Only after the commit, record the tip on the post's ``tip_board``
         and notify the author.

    Previously the tip_board was written before step 3. A request that then
    failed the balance check, or was a duplicate, still inflated the board and
    the "tips" ranking without any money moving.

    For anonymous tips, the author's TIP_IN row no longer names the tipper in
    its description.

    Raises:
        HTTPException 400: non-positive amount, own post, or insufficient balance.
        HTTPException 404: post missing, or the post has no author.
        HTTPException 500: unexpected failure during the transfer (rolled back).
    """
    if amount <= 0:
        raise HTTPException(status_code=400, detail="打赏金额必须大于0")

    # 1. Read-only validation; nothing is written before the transfer succeeds
    posts_db = db.load_data("posts.json", default_data=[])
    target_post = next((p for p in posts_db if p["id"] == post_id), None)
    if target_post is None:
        raise HTTPException(status_code=404, detail="帖子不存在")
    author = target_post.get("author")
    if author == current_user:
        raise HTTPException(status_code=400, detail="不能打赏自己的帖子")
    if not author:
        raise HTTPException(status_code=404, detail="作者账户不存在")
    post_title = target_post.get("title")

    tipper_name = current_user
    try:
        # 🔒 2. Idempotency guard: an identical tip within the last 5 seconds
        recent_cutoff = datetime.datetime.utcnow() - datetime.timedelta(seconds=5)
        duplicate_tx = db_session.query(Transaction).filter(
            Transaction.account == current_user,
            Transaction.tx_type == "TIP_OUT",
            Transaction.amount == -amount,
            Transaction.related_account == author,
            Transaction.created_at >= recent_cutoff
        ).first()
        if duplicate_tx:
            return {"status": "success", "message": "打赏已处理(重复请求)"}

        # 🔒 3. Row-lock both wallets for the duration of the transaction
        tipper_wallet = db_session.query(Wallet).filter(Wallet.account == current_user).with_for_update().first()
        author_wallet = db_session.query(Wallet).filter(Wallet.account == author).with_for_update().first()

        if not tipper_wallet or tipper_wallet.balance < amount:
            raise HTTPException(status_code=400, detail="余额不足")
        if not author_wallet:
            # Create the author's wallet on first income
            author_wallet = Wallet(account=author, balance=0, earn_balance=0, tip_balance=0, frozen_balance=0)
            db_session.add(author_wallet)

        tipper_wallet.balance -= amount
        author_wallet.balance += amount  # spendable income
        author_wallet.tip_balance += amount  # lifetime tip earnings (only ever increases)

        tx_id_tipper = f"TIP_OUT_{int(time.time())}_{uuid.uuid4().hex[:6]}"
        tx_id_author = f"TIP_IN_{int(time.time())}_{uuid.uuid4().hex[:6]}"

        # Chain each new row to the account's most recent transaction hash
        last_tx_tipper = db_session.query(Transaction).filter(Transaction.account == current_user).order_by(Transaction.created_at.desc()).first()
        last_tx_author = db_session.query(Transaction).filter(Transaction.account == author).order_by(Transaction.created_at.desc()).first()
        prev_hash_tipper = last_tx_tipper.tx_hash if last_tx_tipper else "GENESIS_HASH"
        prev_hash_author = last_tx_author.tx_hash if last_tx_author else "GENESIS_HASH"

        users_db = db.load_data("users.json", default_data={})
        author_name = users_db.get(author, {}).get("name", author)
        tipper_name = users_db.get(current_user, {}).get("name", current_user)
        # The author's ledger row must not reveal an anonymous tipper's name
        shown_tipper_name = "匿名用户" if is_anon else tipper_name

        tx_tipper = Transaction(
            tx_id=tx_id_tipper,
            account=current_user,
            tx_type="TIP_OUT",
            amount=-amount,
            related_account=author,
            item_id=post_id,
            prev_hash=prev_hash_tipper,
            tx_hash=calculate_tx_hash(tx_id_tipper, current_user, "TIP_OUT", -amount, prev_hash_tipper),
            description=f"打赏给 {author_name}" + (f" 的帖子《{post_title}》" if post_title else ""),
            item_title=post_title,
            item_type="post",
            related_user_name=author_name
        )
        # NOTE(review): related_account still stores the real tipper even for
        # anonymous tips; confirm it is never shown to the author.
        tx_author = Transaction(
            tx_id=tx_id_author,
            account=author,
            tx_type="TIP_IN",
            amount=amount,
            related_account=current_user,
            item_id=post_id,
            prev_hash=prev_hash_author,
            tx_hash=calculate_tx_hash(tx_id_author, author, "TIP_IN", amount, prev_hash_author),
            description=f"收到 {shown_tipper_name} 的帖子打赏" + (f" ({post_title})" if post_title else ""),
            item_title=post_title,
            item_type="post",
            related_user_name=shown_tipper_name
        )
        db_session.add(tx_tipper)
        db_session.add(tx_author)
        db_session.commit()
    except HTTPException:
        db_session.rollback()
        raise
    except Exception as e:
        db_session.rollback()
        logger.error(f"POST_TIP_ERROR | from={current_user} | post={post_id} | amount={amount} | error={str(e)}")
        raise HTTPException(status_code=500, detail="打赏处理失败,请稍后重试")

    # 📝 Audit log
    logger.info(f"POST_TIP | from={current_user} | to={author} | amount={amount} | post={post_id} | anon={is_anon}")

    # 4. Post-commit bookkeeping. The money has already moved, so failures here
    #    are logged instead of reported as a failed tip.
    try:
        if not _add_tip_to_board(post_id, current_user, amount, is_anon):
            logger.warning(f"POST_TIP_BOARD_MISSING | from={current_user} | post={post_id} | amount={amount}")
        # 🔔 Tip notification (respects anonymity)
        if not is_anon:
            add_notification(author, {
                "type": "tip",
                "from_user": current_user,
                "target_item_id": post_id,
                "target_item_title": post_title or "",
                "content": f"您收到来自 {tipper_name}{amount} 积分帖子打赏"
            })
        else:
            add_notification(author, {
                "type": "tip",
                "from_user": "anonymous",
                "target_item_id": post_id,
                "target_item_title": "",
                "content": f"您收到了一份 {amount} 积分的匿名帖子打赏"
            })
    except Exception as e:
        logger.error(f"POST_TIP_POSTCOMMIT_ERROR | from={current_user} | post={post_id} | amount={amount} | error={str(e)}")

    # 🗂️ Tip totals feed the "tips" sort
    sort_cache.invalidate("posts:")
    return {"status": "success", "message": f"成功打赏 {amount} 积分"}
# ==========================================
# 💬 评论接口(复用通用评论系统)
# ==========================================
@router.get("/api/posts/{post_id}/comments")
async def get_post_comments(post_id: str):
    """
    Return the comments stored for a post, each enriched with author info.

    comments.json maps item_id -> list of comments; an unknown post id simply
    yields an empty list (no 404).
    """
    comments_by_item = db.load_data("comments.json", default_data={})
    # users.json is already a {account: user_info} mapping
    users = db.load_data("users.json", default_data={})

    def _with_author(comment):
        author = comment.get("author")
        profile = users.get(author, {})
        return {
            **comment,
            "author_name": profile.get("name", author),
            "author_avatar": profile.get("avatarDataUrl", "")
        }

    enriched = [_with_author(c) for c in comments_by_item.get(post_id, [])]
    return {"status": "success", "data": enriched}
@router.post("/api/posts/{post_id}/comments")
async def add_post_comment(post_id: str, content: str, current_user: str = Depends(require_auth)):
    """
    Add a comment to a post and increment the post's ``comments`` counter.

    Both writes go through ``db.atomic_update``. The previous version re-saved
    the copy of posts.json loaded at the start of the request. That silently
    discarded likes, ratings or tips written by other endpoints in between.

    Raises:
        HTTPException 400: empty or whitespace-only content.
        HTTPException 404: the post does not exist.
    """
    if not content or not content.strip():
        raise HTTPException(status_code=400, detail="评论内容不能为空")

    # Read-only existence check; posts.json is not written from this snapshot
    posts_db = db.load_data("posts.json", default_data=[])
    if not any(p["id"] == post_id for p in posts_db):
        raise HTTPException(status_code=404, detail="帖子不存在")

    now = int(time.time())
    new_comment = {
        "id": f"comment_{now}_{uuid.uuid4().hex[:6]}",
        "author": current_user,
        "content": content.strip(),
        "created_at": now
    }

    def add_comment(data):
        # comments.json maps item_id -> [comments], newest first
        data.setdefault(post_id, []).insert(0, new_comment)
        return True

    def bump_comment_count(data):
        for post in data:
            if post["id"] == post_id:
                post["comments"] = post.get("comments", 0) + 1
                break
        return True

    db.atomic_update("comments.json", add_comment, default_data={})
    db.atomic_update("posts.json", bump_comment_count, default_data=[])
    # 🗂️ Comment counts may affect cached orderings
    sort_cache.invalidate("posts:")
    return {"status": "success", "data": new_comment}
@router.post("/api/posts/{post_id}/view")
async def record_post_view(post_id: str, current_user: str = Depends(require_auth)):
    """
    Record that the authenticated user viewed a post.

    Delegates the counting to ``record_view`` on posts.json. Per this
    endpoint's original description, total views count each user once while
    daily views grow on every call; that logic lives in ``record_view``.

    Raises:
        HTTPException 404: ``record_view`` reports no such post (returns None).
    """
    counters = record_view("posts.json", post_id, current_user)
    if counters is None:
        raise HTTPException(status_code=404, detail="帖子不存在")

    # 🗂️ View counts feed the "views" / "daily_views" sorts
    sort_cache.invalidate("posts:")
    total_views, today_views = counters["views"], counters["daily_views"]
    return {"status": "success", "views": total_views, "daily_views": today_views}