Spaces:
Sleeping
Sleeping
import datetime
import os
import re
import secrets
from typing import Optional

import jwt
from atproto import Client
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

from database import get_db
app = FastAPI(title="whiteSNS API")

# HS256 signing key for the session tokens issued by login_bluesky.
# The previous fallback was a hard-coded string published with this source,
# so whenever JWT_SECRET was unset anyone could mint a valid token for any
# handle. Fall back to a random per-process key instead: tokens can no longer
# be forged, but they are invalidated on restart and are NOT shared between
# worker processes -- always set JWT_SECRET in a real deployment.
JWT_SECRET = os.getenv("JWT_SECRET") or secrets.token_urlsafe(64)
if not os.getenv("JWT_SECRET"):
    print("WARNING: JWT_SECRET is not set; using a random per-process secret. "
          "Issued tokens will not survive a restart.")
JWT_ALGORITHM = "HS256"

# Bearer-token extractors: `security` makes FastAPI reject requests that lack
# an Authorization header; `security_optional` yields None instead.
security = HTTPBearer()
security_optional = HTTPBearer(auto_error=False)

# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive. Authentication here uses an Authorization bearer header, which
# does not require credentialed CORS -- consider restricting allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Pydantic Models ---
class LoginRequest(BaseModel):
    """Credentials for signing in through Bluesky (AT Protocol)."""

    # Bluesky handle to log in as.
    handle: str
    # Bluesky app password; forwarded to atproto's Client.login and not
    # persisted by login_bluesky.
    app_password: str
class PostCreate(BaseModel):
    """Body for creating a post; set ``parent_id`` to make it a reply."""

    # Post body; @handle mentions in it trigger notifications in create_post.
    text: str
    # id of the post being replied to, or None for a top-level post.
    parent_id: Optional[int] = None
class ProfileUpdate(BaseModel):
    """Replacement values for the caller's profile; every field is overwritten."""

    display_name: str
    # Avatar image reference (presumably a URL -- confirm with the client).
    avatar: str
    bio: str
    # Optional header image; update_profile stores None/empty as "".
    banner: Optional[str] = ""
class PostAction(BaseModel):
    """Body naming a single post for like / pin / bookmark actions."""

    # id of the target post.
    post_id: int
class RepostRequest(BaseModel):
    """Body for a repost; supply ``quote_text`` to make it a quote post."""

    # id of the post being reposted.
    post_id: int
    # Commentary for a quote post; None/empty means a plain repost.
    quote_text: Optional[str] = None
class FollowAction(BaseModel):
    """Body naming the user to follow or unfollow."""

    # Handle of the user whose follow state is toggled.
    target_handle: str
# --- Auth Dependencies ---
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """Decode and validate the bearer JWT of an authenticated request.

    Used as a FastAPI dependency by endpoints that require a signed-in user.

    Returns:
        The decoded payload; ``payload["sub"]`` is the caller's handle.

    Raises:
        HTTPException(401): the token is expired, malformed, wrongly signed,
        or lacks the ``exp``/``sub`` claims.
    """
    try:
        # Tokens minted by login_bluesky always carry "sub" and "exp".
        # Requiring them rejects tokens that would never expire or that name
        # no user (endpoints use token_data.get("sub") as the caller).
        return jwt.decode(
            credentials.credentials,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            options={"require": ["exp", "sub"]},
        )
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token has expired") from None
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token") from None
def optional_verify_token(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
) -> Optional[dict]:
    """Decode the bearer JWT if one was sent, without ever rejecting the request.

    Used by endpoints that work anonymously but personalise the response
    (e.g. ``is_liked``) for signed-in callers.

    Returns:
        The decoded payload, or ``None`` when no Authorization header was
        sent or the token is invalid, expired, or missing ``exp``/``sub``.
    """
    if not credentials:
        return None
    try:
        return jwt.decode(
            credentials.credentials,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            options={"require": ["exp", "sub"]},
        )
    except jwt.InvalidTokenError:
        # Only token problems mean "treat as anonymous". Any other exception
        # is unexpected and should surface instead of silently downgrading
        # the caller.
        return None
# --- Endpoints ---
def read_root():
    """Return a static message confirming the API process is serving requests."""
    payload = dict(message="whiteSNS API is running")
    return payload
def health_check():
    """Report liveness; no downstream dependency (e.g. the database) is probed."""
    return dict(status="ok")
def login_bluesky(req: LoginRequest):
    """Sign in with Bluesky credentials and issue a whiteSNS session token.

    Logs in to Bluesky with ``atproto.Client``. On first login, when a
    database client is available, a ``users`` row is seeded from the Bluesky
    profile. Then an HS256 JWT whose ``sub`` is the Bluesky handle and which
    expires after 7 days is issued.

    Returns:
        On success ``{"success": True, "token", "handle", "did"}``.
        On any failure ``{"success": False, "error": <message>}``; errors
        are reported in the body rather than via an HTTP error status.
    """
    try:
        client = Client()
        profile = client.login(req.handle, req.app_password)
        user_handle = profile.handle
        display_name = getattr(profile, 'display_name', profile.handle) or profile.handle
        avatar = getattr(profile, 'avatar', '') or ''

        db = get_db()
        if db:
            res = db.table('users').select('*').eq('handle', user_handle).execute()
            if not res.data:
                # First login: seed a local profile from the Bluesky one.
                # (Default bio text means "Linked from Bluesky.")
                db.table('users').insert({
                    'handle': user_handle,
                    'display_name': display_name,
                    'avatar': avatar,
                    'bio': 'Blueskyから連携しました。',
                    'banner': ''
                }).execute()

        # datetime.utcnow() is deprecated since Python 3.12. An aware UTC
        # timestamp encodes to the same numeric "exp" claim in PyJWT.
        expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)
        payload = {"sub": user_handle, "exp": expiration}
        token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
        return {"success": True, "token": token, "handle": user_handle, "did": profile.did}
    except Exception as e:
        # Deliberate catch-all: the client expects a JSON body with
        # success=False. The traceback goes to the server log.
        # (Message prefix means "Authentication or initialization error".)
        import traceback
        traceback.print_exc()
        return {"success": False, "error": f"認証または初期化エラー: {str(e)}"}
def get_profile(handle: str):
    """Return the stored profile row for *handle*.

    Unknown handles get a placeholder profile (display name = handle, empty
    avatar/bio/banner) rather than a 404.

    Raises:
        HTTPException(500): no database client is available, or the query failed.
    """
    try:
        db = get_db()
        if not db:
            raise HTTPException(status_code=500, detail="Database connection error")
        res = db.table('users').select('*').eq('handle', handle).execute()
        if not res.data:
            return {"handle": handle, "display_name": handle, "avatar": "", "bio": "", "banner": ""}
        return res.data[0]
    except HTTPException:
        # Let our own HTTP errors through unchanged. Otherwise the generic
        # handler below re-wraps them and replaces the detail with str(exc).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def update_profile(data: ProfileUpdate, token_data: dict = Depends(verify_token)):
    """Overwrite the signed-in user's display name, avatar, bio and banner.

    The target row is chosen by the handle in the verified token, so callers
    can only edit their own profile.

    Returns:
        ``{"success": True, "data": <updated row>}``. ``data`` is None when
        no row matched (e.g. login ran without a database client).

    Raises:
        HTTPException(500): any failure, with the exception text as detail.
    """
    try:
        me = token_data.get("sub")
        client = get_db()
        fields = dict(
            display_name=data.display_name,
            avatar=data.avatar,
            bio=data.bio,
            banner=data.banner or '',
        )
        result = client.table('users').update(fields).eq('handle', me).execute()
        updated_row = result.data[0] if result.data else None
        return {"success": True, "data": updated_row}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_posts(token_data: dict = Depends(optional_verify_token)):
    """Return the 50 newest posts as timeline entries, newest first.

    Each entry carries author display data, like and reply counts, and,
    when the caller is signed in, whether the caller liked the post.
    Returns an empty list when no database client is available.

    Raises:
        HTTPException(500): any query failure (also printed to stdout).
    """
    try:
        db = get_db()
        if not db:
            return []
        res = db.table('posts').select('*, users!posts_handle_fkey(display_name, avatar)').order('created_at', desc=True).limit(50).execute()
        posts = res.data or []

        my_handle = token_data.get("sub") if token_data else None
        my_likes = set()
        if my_handle and posts:
            # Look up the caller's likes only on the posts being returned.
            # The old unbounded query over the whole like history grew without
            # limit and may be capped by a backend max-rows setting (if one is
            # configured), which would yield wrong is_liked flags. A set also
            # gives O(1) membership tests below.
            likes_res = (
                db.table('likes')
                .select('post_id')
                .eq('handle', my_handle)
                .in_('post_id', [p['id'] for p in posts])
                .execute()
            )
            my_likes = {row['post_id'] for row in likes_res.data}

        formatted = []
        for p in posts:
            users_data = p.get("users") or {}
            # NOTE(review): two count queries per post (up to 100 extra
            # round-trips per request); consider a DB view/RPC that returns
            # the counts together with the posts.
            like_res = db.table('likes').select('post_id', count='exact').eq('post_id', p['id']).execute()
            reply_res = db.table('posts').select('id', count='exact').eq('parent_id', p['id']).execute()
            formatted.append({
                "id": p["id"],
                "user": users_data.get("display_name", p["handle"]),
                "handle": p["handle"],
                "text": p["text"],
                "avatar": users_data.get("avatar", ""),
                "time": p["created_at"],
                "parent_id": p.get("parent_id"),
                "repost_id": p.get("repost_id"),
                "likes": like_res.count or 0,
                "replies": reply_res.count or 0,
                "is_liked": p['id'] in my_likes,
            })
        return formatted
    except Exception as e:
        print(f"Get posts error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
def get_post_detail(post_id: int, token_data: dict = Depends(optional_verify_token)):
    """Return one post together with its direct replies, oldest reply first.

    ``is_liked`` reflects the caller when a valid token was sent. Reply
    entries carry ``likes``/``replies`` placeholders of 0; those counts are
    not queried here.

    Raises:
        HTTPException(404): no post has id *post_id*.
        HTTPException(500): any other failure.
    """
    embed = '*, users!posts_handle_fkey(display_name, avatar)'
    try:
        db = get_db()
        found = db.table('posts').select(embed).eq('id', post_id).execute()
        if not found.data:
            raise HTTPException(status_code=404, detail="Post not found")
        post = found.data[0]
        author = post.get("users") or {}

        # Whether the (optional) signed-in viewer has liked this post.
        viewer = token_data.get("sub") if token_data else None
        liked_by_viewer = False
        if viewer:
            mine = db.table('likes').select('post_id').eq('handle', viewer).eq('post_id', post_id).execute()
            liked_by_viewer = len(mine.data) > 0
        like_total = db.table('likes').select('post_id', count='exact').eq('post_id', post_id).execute()

        thread = (
            db.table('posts')
            .select(embed)
            .eq('parent_id', post_id)
            .order('created_at', desc=False)
            .execute()
        )

        def as_reply(row):
            # Shape one reply row; counts are placeholders (not queried).
            who = row.get("users") or {}
            return {
                "id": row["id"], "user": who.get("display_name", row["handle"]),
                "handle": row["handle"], "text": row["text"],
                "avatar": who.get("avatar", ""), "time": row["created_at"],
                "likes": 0,
                "replies": 0,
            }

        replies = [as_reply(r) for r in thread.data]
        return {
            "post": {
                "id": post["id"], "user": author.get("display_name", post["handle"]),
                "handle": post["handle"], "text": post["text"],
                "avatar": author.get("avatar", ""), "time": post["created_at"],
                "likes": like_total.count or 0, "is_liked": liked_by_viewer,
                "repost_id": post.get("repost_id"),
            },
            "replies": replies,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
_MENTION_RE = re.compile(r'@([a-zA-Z0-9_.\-]+)')


def _extract_mentions(text):
    """Return the distinct handles @-mentioned in *text*, in first-seen order.

    Trailing dots are stripped so that a mention ending a sentence
    ("... @alice.bsky.social.") still resolves. Handles are lower-cased
    because atproto handles are case-insensitive; this assumes users rows
    store the normalised handle returned by Bluesky login -- confirm.
    Empty results (e.g. a bare "@.") are dropped.
    """
    names = (m.rstrip(".").lower() for m in _MENTION_RE.findall(text or ""))
    return list(dict.fromkeys(n for n in names if n))


def create_post(post: PostCreate, token_data: dict = Depends(verify_token)):
    """Create a post (a reply when ``parent_id`` is set) authored by the caller.

    After inserting, notifies every existing user @-mentioned in the text
    (except the author). For replies, it also notifies the parent post's
    author unless that is the caller.

    Returns:
        ``{"success": True, "post": <inserted row>}``.

    Raises:
        HTTPException(500): any failure (also printed to stdout).
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        insert_data = {'handle': handle, 'text': post.text}
        if post.parent_id:
            insert_data['parent_id'] = post.parent_id
        res = db.table('posts').insert(insert_data).execute()
        new_post = res.data[0]

        # Mention notifications: one lookup for all mentioned handles instead
        # of one query per mention. Unknown handles are ignored.
        me = handle.lower()
        mentions = [m for m in _extract_mentions(post.text) if m != me]
        if mentions:
            known = db.table('users').select('handle').in_('handle', mentions).execute()
            for row in known.data:
                db.table('notifications').insert({
                    'handle': row['handle'], 'type': 'mention',
                    'actor_handle': handle, 'post_id': new_post['id']
                }).execute()

        # Reply notification for the parent post's author.
        if post.parent_id:
            parent_res = db.table('posts').select('handle').eq('id', post.parent_id).execute()
            if parent_res.data:
                parent_author = parent_res.data[0]['handle']
                if parent_author != handle:
                    db.table('notifications').insert({
                        'handle': parent_author, 'type': 'reply',
                        'actor_handle': handle, 'post_id': new_post['id']
                    }).execute()
        return {"success": True, "post": new_post}
    except Exception as e:
        print(f"Create post error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
def toggle_like(action: PostAction, token_data: dict = Depends(verify_token)):
    """Like or un-like a post for the caller, flipping the current state.

    A new like notifies the post's author (unless that is the caller) and
    includes the post text, truncated to 100 characters plus "..." when longer.

    Returns:
        ``{"action": "liked"}`` or ``{"action": "unliked"}``.

    Raises:
        HTTPException(500): any failure.
    """
    try:
        me = token_data.get("sub")
        db = get_db()
        pid = action.post_id
        existing = db.table('likes').select('*').eq('handle', me).eq('post_id', pid).execute()
        if existing.data:
            db.table('likes').delete().eq('handle', me).eq('post_id', pid).execute()
            return {"action": "unliked"}

        db.table('likes').insert({'handle': me, 'post_id': pid}).execute()
        target = db.table('posts').select('handle, text').eq('id', pid).execute()
        if target.data:
            row = target.data[0]
            author, body = row['handle'], row['text']
            if author != me:
                snippet = body if len(body) <= 100 else body[:100] + '...'
                db.table('notifications').insert({
                    'handle': author, 'type': 'like',
                    'actor_handle': me, 'post_id': pid,
                    'post_text': snippet,
                }).execute()
        return {"action": "liked"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def repost(req: RepostRequest, token_data: dict = Depends(verify_token)):
    """Repost (or quote-post) an existing post as the caller.

    Inserts a new ``posts`` row whose ``repost_id`` points at the original.
    Its text is ``quote_text``, or "" for a plain repost. The original
    author is notified unless they are the caller.

    Returns:
        ``{"success": True, "post": <inserted row>}``.

    Raises:
        HTTPException(404): the original post does not exist.
        HTTPException(500): any other failure.
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        # The original post must exist before we reference it.
        orig = db.table('posts').select('handle, text').eq('id', req.post_id).execute()
        if not orig.data:
            raise HTTPException(status_code=404, detail="Original post not found")
        insert_data = {
            'handle': handle,
            'text': req.quote_text if req.quote_text else "",
            'repost_id': req.post_id
        }
        res = db.table('posts').insert(insert_data).execute()
        author = orig.data[0]['handle']
        if author != handle:
            db.table('notifications').insert({
                'handle': author, 'type': 'repost',
                'actor_handle': handle, 'post_id': res.data[0]['id']
            }).execute()
        return {"success": True, "post": res.data[0]}
    except HTTPException:
        # Without this, the 404 above was caught by the generic handler
        # below and turned into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def pin_post(action: PostAction, token_data: dict = Depends(verify_token)):
    """Pin one of the caller's own posts to their profile.

    Stores the post id in ``users.pinned_post_id`` for the caller.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException(403): the post does not exist or belongs to someone else.
        HTTPException(500): any other failure.
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        # Ownership check: the post must exist AND be authored by the caller.
        chk = db.table('posts').select('id').eq('id', action.post_id).eq('handle', handle).execute()
        if not chk.data:
            raise HTTPException(status_code=403, detail="Not your post")
        db.table('users').update({'pinned_post_id': action.post_id}).eq('handle', handle).execute()
        return {"success": True}
    except HTTPException:
        # Without this, the 403 above was caught by the generic handler
        # below and turned into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def unpin_post(token_data: dict = Depends(verify_token)):
    """Clear the caller's pinned post by setting ``users.pinned_post_id`` to NULL.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException(500): any failure.
    """
    try:
        me = token_data.get("sub")
        client = get_db()
        cleared = {'pinned_post_id': None}
        client.table('users').update(cleared).eq('handle', me).execute()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True}
def toggle_bookmark(action: PostAction, token_data: dict = Depends(verify_token)):
    """Add or remove the caller's bookmark on a post, flipping its state.

    Returns:
        ``{"action": "added"}`` or ``{"action": "removed"}``.

    Raises:
        HTTPException(500): any failure.
    """
    try:
        me = token_data.get("sub")
        db = get_db()
        pid = action.post_id
        found = db.table('bookmarks').select('*').eq('handle', me).eq('post_id', pid).execute()
        if not found.data:
            db.table('bookmarks').insert({'handle': me, 'post_id': pid}).execute()
            return {"action": "added"}
        db.table('bookmarks').delete().eq('handle', me).eq('post_id', pid).execute()
        return {"action": "removed"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_liked_posts(handle: str):
    """Return the posts *handle* has liked, most recently liked first.

    Likes whose post no longer exists are skipped. Entries carry
    ``likes``/``replies`` placeholders of 0; those counts are not queried here.

    Raises:
        HTTPException(500): any query failure.
    """
    try:
        db = get_db()
        likes_res = db.table('likes').select('post_id').eq('handle', handle).order('created_at', desc=True).execute()
        if not likes_res.data:
            return []
        liked_ids = [like['post_id'] for like in likes_res.data]

        # Fetch the liked posts in batches instead of one query per like (the
        # old loop issued one round-trip per liked post). Batching keeps each
        # `in_` filter, which is sent as a query parameter, reasonably short.
        batch_size = 100
        unique_ids = list(dict.fromkeys(liked_ids))
        by_id = {}
        for start in range(0, len(unique_ids), batch_size):
            chunk = unique_ids[start:start + batch_size]
            p_res = db.table('posts').select('*, users!posts_handle_fkey(display_name, avatar)').in_('id', chunk).execute()
            for p in p_res.data:
                by_id[p['id']] = p

        posts = []
        for pid in liked_ids:  # preserve most-recently-liked-first order
            p = by_id.get(pid)
            if p is None:
                continue
            ud = p.get("users") or {}
            posts.append({
                "id": p["id"], "user": ud.get("display_name", p["handle"]),
                "handle": p["handle"], "text": p["text"],
                "avatar": ud.get("avatar", ""), "time": p["created_at"],
                "parent_id": p.get("parent_id"), "likes": 0, "replies": 0,
            })
        return posts
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_follow_info(handle: str, token_data: dict = Depends(optional_verify_token)):
    """Return follower and following counts for *handle*.

    ``is_following`` is True when the signed-in caller follows *handle*,
    and always False for anonymous callers.

    Raises:
        HTTPException(500): any query failure.
    """
    try:
        db = get_db()
        follower_q = db.table('follows').select('follower_handle', count='exact').eq('following_handle', handle).execute()
        following_q = db.table('follows').select('following_handle', count='exact').eq('follower_handle', handle).execute()

        viewer_follows = False
        if token_data:
            viewer = token_data.get('sub')
            edge = db.table('follows').select('*').eq('follower_handle', viewer).eq('following_handle', handle).execute()
            viewer_follows = len(edge.data) > 0

        summary = dict(
            followers=follower_q.count or 0,
            following=following_q.count or 0,
            is_following=viewer_follows,
        )
        return summary
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_followers(handle: str):
    """Return the ``users`` rows of everyone who follows *handle*.

    Followers without a ``users`` row are omitted because the second query
    only returns existing rows.

    Raises:
        HTTPException(500): any query failure.
    """
    try:
        db = get_db()
        edges = db.table('follows').select('follower_handle').eq('following_handle', handle).execute()
        follower_handles = [edge['follower_handle'] for edge in edges.data]
        if not follower_handles:
            return []
        return db.table('users').select('*').in_('handle', follower_handles).execute().data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_following(handle: str):
    """Return the ``users`` rows of everyone *handle* follows.

    Followed handles without a ``users`` row are omitted because the second
    query only returns existing rows.

    Raises:
        HTTPException(500): any query failure.
    """
    try:
        db = get_db()
        edges = db.table('follows').select('following_handle').eq('follower_handle', handle).execute()
        followed_handles = [edge['following_handle'] for edge in edges.data]
        if not followed_handles:
            return []
        return db.table('users').select('*').in_('handle', followed_handles).execute().data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def toggle_follow(action: FollowAction, token_data: dict = Depends(verify_token)):
    """Follow or unfollow ``target_handle`` as the caller, flipping the state.

    Following (not unfollowing) sends the target a ``follow`` notification.
    NOTE(review): every re-follow after an unfollow sends another one.

    Returns:
        ``{"action": "followed"}`` or ``{"action": "unfollowed"}``.

    Raises:
        HTTPException(400): the caller tried to follow themselves.
        HTTPException(500): any other failure.
    """
    try:
        me = token_data.get("sub")
        target = action.target_handle
        if me == target:
            # Detail text means "You cannot follow yourself".
            raise HTTPException(status_code=400, detail="自分自身はフォローできません")
        db = get_db()
        edge = db.table('follows').select('*').eq('follower_handle', me).eq('following_handle', target).execute()
        if not edge.data:
            db.table('follows').insert({'follower_handle': me, 'following_handle': target}).execute()
            db.table('notifications').insert({
                'handle': target, 'type': 'follow',
                'actor_handle': me, 'post_id': None
            }).execute()
            return {"action": "followed"}
        db.table('follows').delete().eq('follower_handle', me).eq('following_handle', target).execute()
        return {"action": "unfollowed"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_notifications(token_data: dict = Depends(verify_token)):
    """Return the caller's 30 newest notifications and mark them read.

    The rows are returned as they were *before* being marked read, so the
    client can still tell which ones are new. Unread notifications older
    than the returned page are marked read as well.

    Raises:
        HTTPException(500): any query failure.
    """
    try:
        handle = token_data.get("sub")
        db = get_db()
        res = db.table('notifications').select('*').eq('handle', handle).order('created_at', desc=True).limit(30).execute()
        if res.data:
            # Only mark rows up to the newest notification being returned.
            # Marking *all* unread rows also swallowed notifications inserted
            # between the SELECT above and this UPDATE, so the user never saw
            # them.
            newest = res.data[0]['created_at']
            (
                db.table('notifications')
                .update({'is_read': True})
                .eq('handle', handle)
                .eq('is_read', False)
                .lte('created_at', newest)
                .execute()
            )
        return res.data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_unread_count(token_data: dict = Depends(verify_token)):
    """Return ``{"count": n}``, the number of the caller's unread notifications.

    Raises:
        HTTPException(500): any query failure.
    """
    try:
        me = token_data.get("sub")
        unread = (
            get_db()
            .table('notifications')
            .select('id', count='exact')
            .eq('handle', me)
            .eq('is_read', False)
            .execute()
        )
        total = unread.count or 0
        return {"count": total}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))