whitesns-api / main.py
Nyanpre's picture
Update main.py
2863c18 verified
import datetime
import os
import re
import warnings
from typing import Optional

import jwt
from atproto import Client
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

from database import get_db
app = FastAPI(title="whiteSNS API")

# Key used to sign session JWTs (login_bluesky) and to verify them
# (verify_token / optional_verify_token). It is read from the JWT_SECRET
# environment variable. The fallback below is published in the source, so
# anyone who knows it can mint a valid token for any handle. The fallback is
# kept for backward compatibility (local development), but a warning is
# emitted whenever it is in effect.
_DEFAULT_JWT_SECRET = "super-secret-default-key-please-change"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    warnings.warn(
        "JWT_SECRET is not set; falling back to the built-in default key. "
        "Tokens signed with it can be forged - set JWT_SECRET in production.",
        RuntimeWarning,
    )
JWT_ALGORITHM = "HS256"

# Bearer-token extractors. `security` makes the Authorization header mandatory.
# `security_optional` (auto_error=False) yields None when the header is missing.
security = HTTPBearer()
security_optional = HTTPBearer(auto_error=False)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True may make
# Starlette echo back the caller's Origin for credentialed requests. Auth here
# uses a bearer token rather than cookies, so credentials may not be needed.
# Confirm with the frontend before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Pydantic Models ---
class LoginRequest(BaseModel):
    """Request body for POST /api/login: Bluesky credentials."""

    handle: str
    # Bluesky app password; passed straight to atproto's Client.login.
    app_password: str
class PostCreate(BaseModel):
    """Request body for POST /api/posts."""

    text: str
    # ID of the post being replied to; None/omitted creates a top-level post.
    parent_id: Optional[int] = None
class ProfileUpdate(BaseModel):
    """Request body for PUT /api/profile; each field overwrites the stored value."""

    display_name: str
    avatar: str  # presumably an image URL - confirm against the frontend
    bio: str
    # Optional. update_profile stores None, omitted, or "" as "".
    banner: Optional[str] = ""
class PostAction(BaseModel):
    """Request body naming a single post (used by like, pin and bookmark)."""

    post_id: int
class RepostRequest(BaseModel):
    """Request body for POST /api/posts/repost."""

    post_id: int  # the post being reposted
    # Commentary for a quote repost. None or empty gives a plain repost
    # (stored text "").
    quote_text: Optional[str] = None
class FollowAction(BaseModel):
    """Request body for POST /api/follows: the handle to follow or unfollow."""

    target_handle: str
# --- Auth Dependencies ---
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: decode the mandatory bearer JWT and return its claims.

    Raises:
        HTTPException(401): "Token has expired" for an expired token, and
            "Invalid token" for any other signature or format problem.
    """
    bearer = credentials.credentials
    try:
        claims = jwt.decode(bearer, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Checked first: ExpiredSignatureError is a subclass of InvalidTokenError.
        raise HTTPException(status_code=401, detail="Token has expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    return claims
def optional_verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional)):
    """FastAPI dependency for endpoints where logging in is optional.

    Returns the decoded JWT claims. Returns None when no bearer token was sent
    or when the token cannot be decoded for any reason (expired, bad
    signature, malformed). It does not raise for bad tokens; the caller is
    simply treated as anonymous.
    """
    if credentials:
        try:
            return jwt.decode(credentials.credentials, JWT_SECRET, algorithms=[JWT_ALGORITHM])
        except Exception:
            # Deliberately lenient: an unusable token means "not logged in".
            pass
    return None
# --- Endpoints ---
@app.get("/")
def read_root():
    """Root endpoint: a fixed banner confirming the API process is up."""
    banner = "whiteSNS API is running"
    return {"message": banner}
@app.get("/api/health")
def health_check():
    """Health probe; always reports ``{"status": "ok"}`` without touching the database."""
    return dict(status="ok")
@app.post("/api/login")
def login_bluesky(req: LoginRequest):
    """Log in with a Bluesky handle and app password, and issue a whiteSNS JWT.

    On first login (when the database is reachable) a local ``users`` row is
    created from the Bluesky profile's display name and avatar, with a default
    bio. The returned token is a JWT whose ``sub`` claim is the Bluesky handle.
    It expires 7 days after issue.

    Returns:
        ``{"success": True, "token", "handle", "did"}`` on success, or
        ``{"success": False, "error": ...}`` (still HTTP 200) on any failure,
        including rejected credentials.
    """
    try:
        client = Client()
        # Authenticates against Bluesky; presumably raises on bad credentials.
        profile = client.login(req.handle, req.app_password)
        user_handle = profile.handle
        display_name = getattr(profile, 'display_name', profile.handle) or profile.handle
        avatar = getattr(profile, 'avatar', '') or ''
        db = get_db()
        # Best-effort: without a database the login still succeeds, but no
        # local profile row is created.
        if db:
            res = db.table('users').select('handle').eq('handle', user_handle).execute()
            if not res.data:
                db.table('users').insert({
                    'handle': user_handle,
                    'display_name': display_name,
                    'avatar': avatar,
                    'bio': 'Blueskyから連携しました。',
                    'banner': ''
                }).execute()
        # Timezone-aware UTC: datetime.utcnow() is deprecated (Python 3.12)
        # and returns a naive datetime.
        expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)
        payload = {"sub": user_handle, "exp": expiration}
        token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
        return {"success": True, "token": token, "handle": user_handle, "did": profile.did}
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {"success": False, "error": f"認証または初期化エラー: {str(e)}"}
@app.get("/api/profile/{handle}")
def get_profile(handle: str):
    """Return the stored ``users`` row for *handle*.

    Unknown handles get a placeholder profile instead of a 404: the display
    name equals the handle, and avatar, bio and banner are empty.

    Raises:
        HTTPException(500): "Database connection error" when no database
            client is available, or the message of any other failure.
    """
    try:
        db = get_db()
        if not db:
            raise HTTPException(status_code=500, detail="Database connection error")
        res = db.table('users').select('*').eq('handle', handle).execute()
        if not res.data:
            return {"handle": handle, "display_name": handle, "avatar": "", "bio": "", "banner": ""}
        return res.data[0]
    except HTTPException:
        # Pass deliberate HTTP errors through unchanged. Without this, the
        # generic handler below would re-wrap them and mangle the detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.put("/api/profile")
def update_profile(data: ProfileUpdate, token_data: dict = Depends(verify_token)):
    """Overwrite the caller's display name, avatar, bio and banner.

    The target row is chosen by the JWT ``sub`` claim, so users can only edit
    their own profile. A missing banner is stored as an empty string.

    Returns:
        ``{"success": True, "data": <updated row or None>}``. ``data`` is None
        when no row matched the handle.
    """
    try:
        me = token_data.get("sub")
        db = get_db()
        fields = dict(
            display_name=data.display_name,
            avatar=data.avatar,
            bio=data.bio,
            banner=data.banner or '',
        )
        updated = db.table('users').update(fields).eq('handle', me).execute()
        rows = updated.data
        first_row = rows[0] if rows else None
        return {"success": True, "data": first_row}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/posts")
def get_posts(token_data: Optional[dict] = Depends(optional_verify_token)):
    """Return the 50 newest posts (replies and reposts included), newest first.

    Each entry carries the author's display name and avatar from ``users``,
    like and reply counts, and ``is_liked``. ``is_liked`` is only True when the
    request carries a valid token and the caller liked the post. Returns ``[]``
    when no database client is available.

    Raises:
        HTTPException(500): any database error.
    """
    try:
        db = get_db()
        if not db:
            return []
        res = db.table('posts').select('*, users!posts_handle_fkey(display_name, avatar)').order('created_at', desc=True).limit(50).execute()
        formatted = []
        my_handle = token_data.get("sub") if token_data else None
        # Look up the caller's likes only among the posts on this page. This is
        # a bounded query instead of fetching every like the user has ever made,
        # a set that grows without limit and may be cut off by the server's
        # row cap. The set gives O(1) membership tests below.
        my_likes = set()
        if my_handle and res.data:
            page_ids = [p['id'] for p in res.data]
            likes_res = db.table('likes').select('post_id').eq('handle', my_handle).in_('post_id', page_ids).execute()
            my_likes = {like['post_id'] for like in likes_res.data}
        for p in res.data:
            users_data = p.get("users") or {}
            # NOTE: two count queries per post (2 x 50 extra round trips).
            like_res = db.table('likes').select('post_id', count='exact').eq('post_id', p['id']).execute()
            reply_res = db.table('posts').select('id', count='exact').eq('parent_id', p['id']).execute()
            formatted.append({
                "id": p["id"],
                "user": users_data.get("display_name", p["handle"]),
                "handle": p["handle"],
                "text": p["text"],
                "avatar": users_data.get("avatar", ""),
                "time": p["created_at"],
                "parent_id": p.get("parent_id"),
                "repost_id": p.get("repost_id"),
                "likes": like_res.count or 0,
                "replies": reply_res.count or 0,
                "is_liked": p['id'] in my_likes
            })
        return formatted
    except Exception as e:
        print(f"Get posts error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/posts/{post_id}")
def get_post_detail(post_id: int, token_data: dict = Depends(optional_verify_token)):
    """Return one post together with its direct replies, oldest reply first.

    ``is_liked`` refers to the caller when a valid token is supplied and is
    False otherwise. Reply entries always report ``likes`` and ``replies`` as
    0, because those counts are not fetched.

    Raises:
        HTTPException(404): no post with this id.
        HTTPException(500): any other failure.
    """
    with_author = '*, users!posts_handle_fkey(display_name, avatar)'

    def reply_view(row):
        # Compact representation of a reply; counters are not looked up.
        who = row.get("users") or {}
        return {
            "id": row["id"], "user": who.get("display_name", row["handle"]),
            "handle": row["handle"], "text": row["text"],
            "avatar": who.get("avatar", ""), "time": row["created_at"],
            "likes": 0,
            "replies": 0,
        }

    try:
        db = get_db()
        found = db.table('posts').select(with_author).eq('id', post_id).execute()
        if not found.data:
            raise HTTPException(status_code=404, detail="Post not found")
        post = found.data[0]
        author = post.get("users") or {}

        viewer = token_data.get("sub") if token_data else None
        liked = False
        if viewer:
            own_like = db.table('likes').select('post_id').eq('handle', viewer).eq('post_id', post_id).execute()
            liked = len(own_like.data) > 0

        like_count = db.table('likes').select('post_id', count='exact').eq('post_id', post_id).execute()
        children = db.table('posts').select(with_author).eq('parent_id', post_id).order('created_at', desc=False).execute()
        reply_list = [reply_view(child) for child in children.data]

        return {
            "post": {
                "id": post["id"], "user": author.get("display_name", post["handle"]),
                "handle": post["handle"], "text": post["text"],
                "avatar": author.get("avatar", ""), "time": post["created_at"],
                "likes": like_count.count or 0, "is_liked": liked,
                "repost_id": post.get("repost_id"),
            },
            "replies": reply_list,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def _extract_mentions(text):
    """Return the set of lower-cased handles mentioned as ``@handle`` in *text*."""
    found = set()
    for raw in re.findall(r'@([a-zA-Z0-9_.\-]+)', text):
        # The character class also swallows punctuation that ends a sentence
        # ("thanks @alice.bsky.social."). AT Protocol handles never end in '.'
        # or '-', so trim those. Handles are case-insensitive (normalised to
        # lower case), so lower-case the match before the database lookup.
        candidate = raw.rstrip('.-').lower()
        if candidate:
            found.add(candidate)
    return found
@app.post("/api/posts")
def create_post(post: PostCreate, token_data: dict = Depends(verify_token)):
    """Create a post, or a reply when ``parent_id`` is set, as the caller.

    After the post row is inserted:
      * every ``@handle`` in the text that matches an existing user (other
        than the author) gets a ``mention`` notification;
      * the parent post's author gets a ``reply`` notification, unless they
        are replying to themselves.

    Returns:
        ``{"success": True, "post": <inserted row>}``.

    Raises:
        HTTPException(500): any database error.
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        insert_data = {'handle': handle, 'text': post.text}
        if post.parent_id:
            insert_data['parent_id'] = post.parent_id
        res = db.table('posts').insert(insert_data).execute()
        new_post = res.data[0]
        # Mention notifications.
        # NOTE(review): the domain part of an e-mail address ("me@example.com")
        # also matches; harmless unless it equals a real handle.
        for mentioned in _extract_mentions(post.text):
            if mentioned == handle.lower():
                continue
            user_check = db.table('users').select('handle').eq('handle', mentioned).execute()
            if user_check.data:
                db.table('notifications').insert({
                    'handle': mentioned, 'type': 'mention',
                    'actor_handle': handle, 'post_id': new_post['id']
                }).execute()
        # Reply notification
        if post.parent_id:
            parent_res = db.table('posts').select('handle').eq('id', post.parent_id).execute()
            if parent_res.data:
                parent_author = parent_res.data[0]['handle']
                if parent_author != handle:
                    db.table('notifications').insert({
                        'handle': parent_author, 'type': 'reply',
                        'actor_handle': handle, 'post_id': new_post['id']
                    }).execute()
        return {"success": True, "post": new_post}
    except Exception as e:
        print(f"Create post error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/posts/like")
def toggle_like(action: PostAction, token_data: dict = Depends(verify_token)):
    """Like the post, or remove the caller's existing like.

    On a new like, the post's author gets a ``like`` notification carrying a
    snippet of the post (first 100 characters plus "..." when longer), unless
    the author liked their own post.

    Returns:
        ``{"action": "liked"}`` or ``{"action": "unliked"}``.
    """
    try:
        me = token_data.get("sub")
        db = get_db()
        pid = action.post_id
        existing = db.table('likes').select('*').eq('handle', me).eq('post_id', pid).execute()
        if existing.data:
            db.table('likes').delete().eq('handle', me).eq('post_id', pid).execute()
            return {"action": "unliked"}

        db.table('likes').insert({'handle': me, 'post_id': pid}).execute()
        target = db.table('posts').select('handle, text').eq('id', pid).execute()
        if target.data:
            row = target.data[0]
            author, body = row['handle'], row['text']
            if author != me:
                snippet = body if len(body) <= 100 else body[:100] + '...'
                db.table('notifications').insert({
                    'handle': author, 'type': 'like',
                    'actor_handle': me, 'post_id': pid,
                    'post_text': snippet,
                }).execute()
        return {"action": "liked"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/posts/repost")
def repost(req: RepostRequest, token_data: dict = Depends(verify_token)):
    """Repost or quote-repost an existing post as the caller.

    Inserts a new ``posts`` row with ``repost_id`` pointing at the original
    and ``text`` set to the quote (``""`` for a plain repost). The original
    author is then notified, unless they reposted their own post.

    Returns:
        ``{"success": True, "post": <inserted row>}``.

    Raises:
        HTTPException(404): the original post does not exist.
        HTTPException(500): any database error.
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        # Check if original post exists
        orig = db.table('posts').select('handle, text').eq('id', req.post_id).execute()
        if not orig.data:
            raise HTTPException(status_code=404, detail="Original post not found")
        insert_data = {
            'handle': handle,
            'text': req.quote_text or "",
            'repost_id': req.post_id,
        }
        res = db.table('posts').insert(insert_data).execute()
        new_post = res.data[0]
        author = orig.data[0]['handle']
        if author != handle:
            # The notification points at the new repost row, not the original.
            db.table('notifications').insert({
                'handle': author, 'type': 'repost',
                'actor_handle': handle, 'post_id': new_post['id']
            }).execute()
        return {"success": True, "post": new_post}
    except HTTPException:
        # Keep the 404 above as a 404. The generic handler below would
        # otherwise turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/posts/pin")
def pin_post(action: PostAction, token_data: dict = Depends(verify_token)):
    """Pin one of the caller's own posts to their profile.

    Stores the post id in ``users.pinned_post_id``, replacing any earlier pin.

    Raises:
        HTTPException(403): the post does not exist or belongs to someone else.
        HTTPException(500): any database error.
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        # Verify ownership: the post must exist AND be authored by the caller.
        chk = db.table('posts').select('id').eq('id', action.post_id).eq('handle', handle).execute()
        if not chk.data:
            raise HTTPException(status_code=403, detail="Not your post")
        db.table('users').update({'pinned_post_id': action.post_id}).eq('handle', handle).execute()
        return {"success": True}
    except HTTPException:
        # Keep the 403 above as a 403. The generic handler below would
        # otherwise turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/posts/unpin")
def unpin_post(token_data: dict = Depends(verify_token)):
    """Clear the caller's pinned post by setting ``users.pinned_post_id`` to NULL.

    Reports success even if nothing was pinned.

    Raises:
        HTTPException(500): any database error.
    """
    try:
        owner = token_data.get("sub")
        users = get_db().table('users')
        users.update({'pinned_post_id': None}).eq('handle', owner).execute()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True}
@app.post("/api/posts/bookmark")
def toggle_bookmark(action: PostAction, token_data: dict = Depends(verify_token)):
    """Bookmark the post, or remove the caller's existing bookmark.

    Returns:
        ``{"action": "added"}`` or ``{"action": "removed"}``.
    """
    try:
        owner = token_data.get("sub")
        db = get_db()
        pid = action.post_id
        already = db.table('bookmarks').select('*').eq('handle', owner).eq('post_id', pid).execute()
        if not already.data:
            db.table('bookmarks').insert({'handle': owner, 'post_id': pid}).execute()
            return {"action": "added"}
        db.table('bookmarks').delete().eq('handle', owner).eq('post_id', pid).execute()
        return {"action": "removed"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/likes/{handle}")
def get_liked_posts(handle: str):
    """Return the posts *handle* has liked, most recently liked first.

    Likes whose post no longer exists are skipped. Per-post like and reply
    counters are not computed here and are always reported as 0.

    Raises:
        HTTPException(500): any database error.
    """
    # Number of ids per IN (...) filter. This keeps each request's filter
    # list (sent in the query string) bounded.
    batch_size = 100
    try:
        db = get_db()
        likes_res = db.table('likes').select('post_id').eq('handle', handle).order('created_at', desc=True).execute()
        if not likes_res.data:
            return []
        liked_ids = [like['post_id'] for like in likes_res.data]
        # Fetch the liked posts with a few batched IN queries rather than one
        # request per like.
        unique_ids = list(dict.fromkeys(liked_ids))
        posts_by_id = {}
        for start in range(0, len(unique_ids), batch_size):
            batch = unique_ids[start:start + batch_size]
            p_res = db.table('posts').select('*, users!posts_handle_fkey(display_name, avatar)').in_('id', batch).execute()
            for p in p_res.data:
                posts_by_id[p['id']] = p
        posts = []
        # Walk the likes (not the fetched posts) to keep most-recent-like order.
        for pid in liked_ids:
            p = posts_by_id.get(pid)
            if p is None:
                continue
            ud = p.get("users") or {}
            posts.append({
                "id": p["id"], "user": ud.get("display_name", p["handle"]),
                "handle": p["handle"], "text": p["text"],
                "avatar": ud.get("avatar", ""), "time": p["created_at"],
                "parent_id": p.get("parent_id"), "likes": 0, "replies": 0,
            })
        return posts
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/follows/{handle}")
def get_follow_info(handle: str, token_data: dict = Depends(optional_verify_token)):
    """Return follower and following counts for *handle*.

    ``is_following`` is True only when the request carries a valid token and
    that user follows *handle*.
    """
    try:
        db = get_db()
        follower_count = (
            db.table('follows')
            .select('follower_handle', count='exact')
            .eq('following_handle', handle)
            .execute()
            .count
        )
        following_count = (
            db.table('follows')
            .select('following_handle', count='exact')
            .eq('follower_handle', handle)
            .execute()
            .count
        )
        viewer_follows = False
        if token_data:
            viewer = token_data.get('sub')
            edge = db.table('follows').select('*').eq('follower_handle', viewer).eq('following_handle', handle).execute()
            viewer_follows = len(edge.data) > 0
        return {
            "followers": follower_count or 0,
            "following": following_count or 0,
            "is_following": viewer_follows,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/follows/{handle}/followers")
def get_followers(handle: str):
    """Return the ``users`` rows of everyone who follows *handle*.

    Followers without a ``users`` row are omitted. No ordering is requested,
    so rows come back in whatever order the database returns them.
    """
    try:
        db = get_db()
        edges = db.table('follows').select('follower_handle').eq('following_handle', handle).execute()
        follower_handles = [edge['follower_handle'] for edge in edges.data]
        if follower_handles:
            return db.table('users').select('*').in_('handle', follower_handles).execute().data
        return []
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/follows/{handle}/following")
def get_following(handle: str):
    """Return the ``users`` rows of everyone *handle* follows.

    Followed handles without a ``users`` row are omitted. No ordering is
    requested, so rows come back in whatever order the database returns them.
    """
    try:
        db = get_db()
        edges = db.table('follows').select('following_handle').eq('follower_handle', handle).execute()
        followed_handles = [edge['following_handle'] for edge in edges.data]
        if followed_handles:
            return db.table('users').select('*').in_('handle', followed_handles).execute().data
        return []
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/follows")
def toggle_follow(action: FollowAction, token_data: dict = Depends(verify_token)):
    """Follow ``target_handle``, or unfollow it if the caller already follows it.

    A new follow also creates a ``follow`` notification for the target.

    Returns:
        ``{"action": "followed"}`` or ``{"action": "unfollowed"}``.

    Raises:
        HTTPException(400): the caller tried to follow themselves.
        HTTPException(500): any database error.
    """
    try:
        follower = token_data.get("sub")
        followee = action.target_handle
        if follower == followee:
            raise HTTPException(status_code=400, detail="自分自身はフォローできません")
        db = get_db()
        edge = db.table('follows').select('*').eq('follower_handle', follower).eq('following_handle', followee).execute()
        if not edge.data:
            db.table('follows').insert({'follower_handle': follower, 'following_handle': followee}).execute()
            db.table('notifications').insert({
                'handle': followee, 'type': 'follow',
                'actor_handle': follower, 'post_id': None,
            }).execute()
            return {"action": "followed"}
        db.table('follows').delete().eq('follower_handle', follower).eq('following_handle', followee).execute()
        return {"action": "unfollowed"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/notifications")
def get_notifications(token_data: dict = Depends(verify_token)):
    """Return the caller's 30 newest notifications and mark them read.

    Rows are returned exactly as selected, i.e. with their ``is_read`` value
    from before this call. After the select, every unread notification
    created up to the newest returned one is marked read. This includes older
    ones beyond the first 30.

    Raises:
        HTTPException(500): any database error.
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        res = db.table('notifications').select('*').eq('handle', handle).order('created_at', desc=True).limit(30).execute()
        if res.data:
            # Bound the UPDATE by the newest notification actually returned.
            # A notification inserted between the SELECT above and this UPDATE
            # has not been shown to the user yet and must stay unread. With no
            # rows returned there is nothing the user has seen, so nothing to mark.
            mark_read = db.table('notifications').update({'is_read': True}).eq('handle', handle).eq('is_read', False)
            newest_seen = res.data[0].get('created_at')
            if newest_seen is not None:
                mark_read = mark_read.lte('created_at', newest_seen)
            mark_read.execute()
        return res.data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/notifications/unread_count")
def get_unread_count(token_data: dict = Depends(verify_token)):
    """Return ``{"count": n}``, the number of the caller's unread notifications."""
    try:
        owner = token_data.get("sub")
        db = get_db()
        unread = db.table('notifications').select('id', count='exact').eq('handle', owner).eq('is_read', False).execute()
        total = unread.count
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"count": total or 0}