|
|
|
|
|
from fastapi import FastAPI, Depends, HTTPException, status, Path, Request
|
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
from jose import JWTError, jwt
|
|
|
from passlib.context import CryptContext
|
|
|
from datetime import datetime, timedelta
|
|
|
from typing import Optional, List, Dict, Any
|
|
|
from pydantic import BaseModel, EmailStr, ValidationError
|
|
|
import os
|
|
|
from dotenv import load_dotenv
|
|
|
from beanie import Document, init_beanie
|
|
|
from motor.motor_asyncio import AsyncIOMotorClient
|
|
|
|
|
|
|
|
|
from mistralai import Mistral
|
|
|
import uvicorn
|
|
|
|
|
|
|
|
|
load_dotenv()

# --- Runtime configuration, read from the environment (or a .env file) ---

MONGODB_URL = os.getenv("MONGODB_URL", "mongodb://localhost:27017")

DATABASE_NAME = os.getenv("DATABASE_NAME", "mindmap_db")

SECRET_KEY = os.getenv("SECRET_KEY")

MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

# CORS_ORIGINS is a comma-separated list of allowed browser origins.
# Each entry has surrounding whitespace stripped, so "http://a, http://b"
# works; without stripping, " http://b" would never match an Origin header.
# Empty entries (e.g. from a trailing comma) are dropped.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "http://localhost:3000").split(",")
    if origin.strip()
]

# Refuse to start without a signing key: every issued JWT depends on it.
if not SECRET_KEY:
    raise RuntimeError("SECRET_KEY environment variable must be set for security!")

# Algorithm used to sign and verify access tokens.
ALGORITHM = "HS256"

# Lifetime of tokens issued by /signup and /login, in minutes. It can be
# overridden through the environment and defaults to 30.
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
|
|
|
|
|
|
|
|
app = FastAPI(title="MindMapX API")

# Browsers may call the API only from the configured origins. For those
# origins, credentials are allowed, along with any method and any header.
_cors_options = dict(
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
|
|
|
# Optional rate limiting, enabled only when the `slowapi` package is installed.
# Otherwise `limiter` is None.
try:
    from slowapi import Limiter, _rate_limit_exceeded_handler
    from slowapi.errors import RateLimitExceeded
    from slowapi.util import get_remote_address

    # Limits are keyed by the client's remote address.
    limiter = Limiter(key_func=get_remote_address)
    app.state.limiter = limiter
    # Register the handler for slowapi's own exception type, not for status
    # code 429. The handler reads rate-limit state that slowapi stores on the
    # request only when slowapi raised the 429. Any other HTTPException(429)
    # routed to it would therefore crash instead of producing a response.
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
except ImportError:
    limiter = None

# NOTE(review): no route in this file is decorated with `@limiter.limit(...)`,
# so the limiter is configured but never applied (e.g. to /login or /signup).
|
|
|
|
|
|
|
|
|
@app.exception_handler(ValidationError)
async def validation_exception_handler(request: Request, exc: ValidationError):
    """Render a pydantic ``ValidationError`` as an HTTP 422 response.

    The body is ``{"detail": [{"msg": ..., "loc": [...]}, ...]}``, with one
    entry per validation error.

    NOTE(review): depending on the FastAPI version, invalid request payloads
    may be raised as ``RequestValidationError`` and handled by FastAPI's
    default handler instead. Confirm which errors actually reach this one.
    """
    errors = [
        # err["msg"] is the human-readable message. str(err) would serialize the
        # whole error dict (type, input, ctx, url, ...) and could echo the
        # submitted input back to the client.
        {"msg": err["msg"], "loc": err["loc"]}
        for err in exc.errors()
    ]
    return JSONResponse(status_code=422, content={"detail": errors})
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as ``{"detail": ...}`` with its status code.

    Any headers attached to the exception are forwarded. For example,
    ``get_current_user`` and ``/login`` set ``WWW-Authenticate: Bearer`` on
    their 401 responses. Without this forwarding that header would never
    reach the client.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
|
|
|
|
|
# Password hashing context; bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Extracts the bearer token for protected routes. Clients obtain tokens at /login.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")

# Mistral SDK client. It is None when MISTRAL_API_KEY is unset, in which case
# the /models and /chat endpoints respond with HTTP 500.
if MISTRAL_API_KEY:
    client = Mistral(api_key=MISTRAL_API_KEY)
else:
    client = None
|
|
|
|
|
|
|
|
|
# Fields shared by the signup payload (UserCreate) and the stored User document.
class UserBase(BaseModel):
    # Login name; signup rejects duplicates.
    username: str
    # Optional; when supplied, it must be a syntactically valid e-mail address.
    email: Optional[EmailStr] = None
|
|
|
|
|
|
# Request body for /signup: the public user fields plus the plaintext password.
# signup hashes the password before anything is stored.
class UserCreate(UserBase):
    password: str
|
|
|
|
|
|
# Stored user document. Only the password hash is persisted; this model has
# no plaintext password field.
class User(Document, UserBase):
    # Hash produced by get_password_hash (the bcrypt CryptContext).
    hashed_password: str
    # Timestamps set by signup from datetime.utcnow() (naive UTC).
    created_at: datetime
    updated_at: datetime

    class Settings:
        # MongoDB collection name used by Beanie.
        name = "users"
        # Indexes on lookup fields. They are listed as bare field names, so
        # presumably non-unique. Uniqueness currently rests on the
        # check-then-insert in signup.
        # NOTE(review): consider unique indexes to close that race.
        indexes = ["username", "email"]
|
|
|
|
|
|
# One mind-map node as exchanged with the frontend.
class Node(BaseModel):
    # Node identifier; the workflow endpoints reject duplicate ids within one map.
    id: str
    # Free-form node payload, stored as given.
    data: Dict[str, Any]
    # Coordinate name -> value; presumably {"x": ..., "y": ...}. TODO: confirm with the client.
    position: Dict[str, float]
|
|
|
|
|
|
# A link between two nodes of the same mind map.
class Edge(BaseModel):
    id: str
    # Presumably node ids. They are not checked against the map's nodes anywhere in this file.
    source: str
    target: str
|
|
|
|
|
|
# Editable mind-map content. Also the request body for creating and updating workflows.
class MindMapBase(BaseModel):
    name: str = "Untitled"
    # pydantic copies mutable defaults per instance, so these [] are not shared.
    nodes: List[Node] = []
    edges: List[Edge] = []
|
|
|
|
|
|
# Stored mind map document.
class MindMap(Document, MindMapBase):
    # Owner's username. The workflow endpoints compare it with current_user.username.
    user_id: str
    # Timestamps set by the workflow endpoints from datetime.utcnow() (naive UTC).
    created_at: datetime
    updated_at: datetime

    class Settings:
        # MongoDB collection name used by Beanie.
        name = "mindmaps"
        # Supports the per-owner listing query and ordering by creation time.
        indexes = ["user_id", "created_at"]
|
|
|
|
|
|
|
|
|
# Response body of /signup and /login.
class Token(BaseModel):
    # Signed JWT produced by create_access_token.
    access_token: str
    # Always "bearer" in this file.
    token_type: str
|
|
|
|
|
|
# Claims extracted from a decoded access token. get_current_user fills
# `username` from the JWT "sub" claim.
class TokenData(BaseModel):
    username: Optional[str] = None
|
|
|
|
|
|
|
|
|
async def init_db():
    """Connect to MongoDB and register the User and MindMap documents with Beanie."""
    # Named distinctly so it is not confused with the module-level Mistral `client`.
    mongo = AsyncIOMotorClient(MONGODB_URL)
    registered_models = [User, MindMap]
    await init_beanie(
        database=mongo[DATABASE_NAME],
        document_models=registered_models,
    )
|
|
|
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored *hashed_password*."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash *password* with the module's bcrypt ``CryptContext`` for storage."""
    digest = pwd_context.hash(password)
    return digest
|
|
|
|
|
|
async def get_user(username: str) -> Optional[User]:
    """Fetch the user whose username equals *username*, or None if there is none."""
    criteria = {"username": username}
    match = await User.find_one(criteria)
    return match
|
|
|
|
|
|
async def authenticate_user(username: str, password: str) -> Optional[User]:
    """Return the matching user when *username*/*password* are valid, else None.

    Both failure paths perform bcrypt work. Response time therefore does not
    reveal whether *username* exists (user-enumeration hardening).
    """
    user = await get_user(username)
    if not user:
        # Spend comparable hashing time to a real verification, so an unknown
        # username is not distinguishable from a wrong password.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Return a JWT signed with SECRET_KEY containing *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed (callers in this file pass ``{"sub": username}``).
            The dict is copied, not mutated.
        expires_delta: Token lifetime. Defaults to 15 minutes when omitted.
            An explicit ``timedelta(0)`` is honoured rather than being
            replaced by the default.

    Returns:
        The encoded token string.
    """
    from datetime import timezone  # local import keeps this block self-contained

    to_encode = data.copy()
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
    # python-jose converts datetime claims to a UTC epoch timestamp.
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """FastAPI dependency that resolves the request's bearer token to a ``User``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` in any of these
            cases: the token fails JWT decoding or verification; it has no
            ``sub`` claim; the claim is not usable as a username; or the named
            user no longer exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        if username is None:
            raise credentials_exception
        # A malformed `sub` (e.g. a non-string, which pydantic v2 does not
        # coerce) fails validation here. It is reported as a bad token (401)
        # rather than escaping as a 422.
        token_data = TokenData(username=username)
    except (JWTError, ValidationError) as err:
        raise credentials_exception from err
    user = await get_user(username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def on_startup():
    """Initialise the MongoDB/Beanie connection before the app serves requests."""
    database_ready = init_db()
    await database_ready
|
|
|
|
|
|
|
|
|
@app.post("/signup", response_model=Token)
async def signup(user: UserCreate):
    """Register a new user and immediately return an access token for them.

    Raises:
        HTTPException: 400 when the username is already registered, or when
            an e-mail is supplied that another user already has.
    """
    if await get_user(user.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )
    if user.email:
        if await User.find_one({"email": user.email}):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered",
            )
    # NOTE(review): this check-then-insert is racy. Two concurrent signups can
    # both pass the checks above unless username/email carry unique indexes.

    # Take one timestamp so a new account has created_at == updated_at.
    now = datetime.utcnow()
    new_user = User(
        username=user.username,
        email=user.email,
        hashed_password=get_password_hash(user.password),
        created_at=now,
        updated_at=now,
    )
    await new_user.insert()
    access_token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
|
|
@app.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form-encoded username/password credentials for a bearer token.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the
            credentials do not match a user.
    """
    authenticated = await authenticate_user(form_data.username, form_data.password)
    if not authenticated:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    token = create_access_token(
        data={"sub": authenticated.username},
        expires_delta=lifetime,
    )
    return {"access_token": token, "token_type": "bearer"}
|
|
|
|
|
|
@app.post("/logout")
async def logout(current_user: User = Depends(get_current_user)):
    """Acknowledge a logout from an authenticated user.

    Nothing is revoked server-side. The JWT stays valid until it expires,
    so the client is responsible for discarding it.
    """
    confirmation = "Successfully logged out"
    return {"message": confirmation}
|
|
|
|
|
|
|
|
|
@app.get("/api/workflows", response_model=List[Dict[str, Any]])
async def get_workflows(current_user: User = Depends(get_current_user)):
    """List every mind map owned by the authenticated user.

    Each item carries id, name, nodes, edges and ISO-8601 ``updatedAt`` /
    ``createdAt`` values (None when a timestamp is missing).
    """
    owned_maps = await MindMap.find({"user_id": current_user.username}).to_list()
    summaries = []
    for mind_map in owned_maps:
        updated = mind_map.updated_at
        created = mind_map.created_at
        summaries.append(
            {
                "id": str(mind_map.id),
                "name": mind_map.name,
                "nodes": mind_map.nodes,
                "edges": mind_map.edges,
                "updatedAt": updated.isoformat() if updated else None,
                "createdAt": created.isoformat() if created else None,
            }
        )
    return summaries
|
|
|
|
|
|
@app.post("/api/workflows", response_model=Dict[str, Any])
async def create_workflow(
    workflow: MindMapBase, current_user: User = Depends(get_current_user)
):
    """Create a mind map owned by the authenticated user and return it.

    Raises:
        HTTPException: 400 when two nodes share the same id.
    """
    node_ids = [node.id for node in workflow.nodes]
    if len(node_ids) != len(set(node_ids)):
        raise HTTPException(status_code=400, detail="Duplicate node IDs are not allowed")
    # NOTE(review): edge source/target values are not checked against node_ids.

    # Take one timestamp so a fresh map reports identical createdAt/updatedAt.
    now = datetime.utcnow()
    new_map = MindMap(
        user_id=current_user.username,
        name=workflow.name,
        nodes=workflow.nodes,
        edges=workflow.edges,
        created_at=now,
        updated_at=now,
    )
    await new_map.insert()
    return {
        "id": str(new_map.id),
        "name": new_map.name,
        "nodes": new_map.nodes,
        "edges": new_map.edges,
        "updatedAt": new_map.updated_at.isoformat() if new_map.updated_at else None,
        "createdAt": new_map.created_at.isoformat() if new_map.created_at else None,
    }
|
|
|
|
|
|
@app.get("/api/workflows/{workflow_id}", response_model=Dict[str, Any])
async def get_workflow(
    workflow_id: str, current_user: User = Depends(get_current_user)
):
    """Return one mind map if it exists and belongs to the authenticated user.

    Raises:
        HTTPException: 404 when the map is missing or owned by someone else.
            The same response covers both cases, so the existence of other
            users' maps is not confirmed.
    """
    mind_map = await MindMap.get(workflow_id)
    if not (mind_map and mind_map.user_id == current_user.username):
        raise HTTPException(status_code=404, detail="Workflow not found")
    updated = mind_map.updated_at
    created = mind_map.created_at
    return {
        "id": str(mind_map.id),
        "name": mind_map.name,
        "nodes": mind_map.nodes,
        "edges": mind_map.edges,
        "updatedAt": updated.isoformat() if updated else None,
        "createdAt": created.isoformat() if created else None,
    }
|
|
|
|
|
|
@app.put("/api/workflows/{workflow_id}", response_model=Dict[str, Any])
async def update_workflow(
    workflow_id: str,
    workflow: MindMapBase,
    current_user: User = Depends(get_current_user),
):
    """Replace the name, nodes and edges of one of the user's mind maps.

    Raises:
        HTTPException: 404 when the map is missing or owned by someone else;
            400 when two nodes share the same id.
    """
    mind_map = await MindMap.get(workflow_id)
    if not (mind_map and mind_map.user_id == current_user.username):
        raise HTTPException(status_code=404, detail="Workflow not found")

    distinct_ids = {node.id for node in workflow.nodes}
    if len(distinct_ids) != len(workflow.nodes):
        raise HTTPException(status_code=400, detail="Duplicate node IDs are not allowed")

    mind_map.name = workflow.name
    mind_map.nodes = workflow.nodes
    mind_map.edges = workflow.edges
    mind_map.updated_at = datetime.utcnow()
    await mind_map.save()

    updated = mind_map.updated_at
    created = mind_map.created_at
    return {
        "id": str(mind_map.id),
        "name": mind_map.name,
        "nodes": mind_map.nodes,
        "edges": mind_map.edges,
        "updatedAt": updated.isoformat() if updated else None,
        "createdAt": created.isoformat() if created else None,
    }
|
|
|
|
|
|
@app.delete("/api/workflows/{workflow_id}")
async def delete_workflow(
    workflow_id: str, current_user: User = Depends(get_current_user)
):
    """Delete one of the authenticated user's mind maps.

    Raises:
        HTTPException: 404 when the map is missing or owned by someone else.
    """
    target = await MindMap.get(workflow_id)
    if not (target and target.user_id == current_user.username):
        raise HTTPException(status_code=404, detail="Workflow not found")
    await target.delete()
    return {"message": "Workflow deleted successfully"}
|
|
|
|
|
|
|
|
|
# One conversation turn forwarded to Mistral (via .dict()).
class ChatMessage(BaseModel):
    # Presumably "system" / "user" / "assistant". Not validated here.
    role: str
    content: str
|
|
|
|
|
|
# Request body for /chat. Apart from `stream`, every field is passed through to the Mistral SDK call.
class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    # Client-selected model id; not restricted to any allow-list here.
    model: str = "mistral-large-latest"
    # True -> plain-text streaming response; False -> single JSON response.
    stream: bool = False
    safe_prompt: Optional[bool] = False
    # Optional stop sequences forwarded to the model.
    stop: Optional[List[str]] = None
|
|
|
|
|
|
@app.get("/models")
def list_models():
    """Return ``{"models": [...]}`` with the model ids available to the Mistral key.

    Declared as a plain ``def`` so that FastAPI runs the blocking SDK call in
    its threadpool rather than on the event loop.

    Raises:
        HTTPException: 500 when no Mistral API key is configured.
    """
    # NOTE(review): unauthenticated endpoint that uses the server's API key.
    if not client:
        raise HTTPException(status_code=500, detail="Mistral API key not configured.")
    models = client.models.list()
    # The SDK's model-list response may carry `data=None`. Treat that as
    # "no models" instead of failing with a TypeError.
    return {"models": [model.id for model in (models.data or [])]}
|
|
|
|
|
|
@app.post("/chat")
async def chat(request: ChatRequest):
    """Proxy a chat completion request to Mistral.

    When ``request.stream`` is true, the reply text is streamed as
    ``text/plain`` chunks. Otherwise ``{"response": <text>}`` is returned.

    Raises:
        HTTPException: 500 when no Mistral API key is configured; 502 when
            Mistral returns no completion choices.
    """
    # NOTE(review): this endpoint is unauthenticated, so anyone who can reach
    # the API can spend the server's Mistral quota. Adding
    # Depends(get_current_user) would fix that but breaks current clients.
    if not client:
        raise HTTPException(status_code=500, detail="Mistral API key not configured.")

    messages = [msg.dict() for msg in request.messages]

    if request.stream:
        def stream_response():
            # Sync generator: Starlette iterates it in a worker thread, so the
            # blocking SDK stream does not stall the event loop.
            response = client.chat.stream(
                model=request.model,
                messages=messages,
                safe_prompt=request.safe_prompt,
                stop=request.stop,
            )
            for chunk in response:
                choices = chunk.data.choices
                # Skip events that carry no choices instead of raising IndexError.
                if not choices:
                    continue
                content = choices[0].delta.content
                if content:
                    yield content

        return StreamingResponse(stream_response(), media_type="text/plain")

    # Use the SDK's async variant. The synchronous `complete` would block the
    # event loop, and with it every other request, for the whole Mistral round trip.
    response = await client.chat.complete_async(
        model=request.model,
        messages=messages,
        safe_prompt=request.safe_prompt,
        stop=request.stop,
    )
    if not response or not response.choices:
        raise HTTPException(status_code=502, detail="Mistral returned no completion.")
    return JSONResponse({
        "response": response.choices[0].message.content
    })
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Identify the service at the API root."""
    banner = {"message": "Mind Map API"}
    return banner
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: always reports ``{"status": "ok"}`` and checks no dependencies."""
    report = {"status": "ok"}
    return report
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point. `uvicorn` is already imported at module level,
    # so it is not re-imported here. HOST/PORT can be overridden via the
    # environment; the defaults keep the original 0.0.0.0:8000 bind.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
    )
|
|
|
|
|
|
|