Spaces:
Running
Running
| // internal/auth/handler/auth_handler.go | |
| package handler | |
| import ( | |
| "context" | |
| "encoding/json" | |
| "errors" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "net/url" | |
| "os" | |
| "time" | |
| "log" | |
| "server/internal/auth/models" | |
| "server/internal/auth/repository" | |
| "server/internal/auth/service" | |
| "server/internal/shared/crypto" | |
| "server/internal/shared/utils" | |
| "go.mongodb.org/mongo-driver/mongo" | |
| ) | |
// LoggingResponseWriter wraps an http.ResponseWriter and records the status
// code sent to the client so that middleware can log it after the handler
// returns.
//
// StatusCode keeps whatever value the wrapper was constructed with until the
// handler commits a status. WriteHeader records the explicit code. Write
// records the implicit 200 OK only when StatusCode is still 0.
type LoggingResponseWriter struct {
	http.ResponseWriter
	StatusCode int
}

// WriteHeader records statusCode and forwards it to the wrapped writer.
func (rw *LoggingResponseWriter) WriteHeader(statusCode int) {
	rw.StatusCode = statusCode
	rw.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards b to the wrapped writer. If no status has been recorded yet,
// it first records 200 OK. That is the status net/http sends implicitly when a
// handler writes a body without calling WriteHeader. Without this method the
// implicit status never passes through WriteHeader above, and StatusCode would
// remain 0.
func (rw *LoggingResponseWriter) Write(b []byte) (int, error) {
	if rw.StatusCode == 0 {
		rw.StatusCode = http.StatusOK
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer. This lets http.ResponseController reach
// optional capabilities such as Flush or deadlines that the wrapper itself
// does not expose.
func (rw *LoggingResponseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
| func Login(w http.ResponseWriter, r *http.Request) { | |
| log.Println("Login handler triggered") | |
| // Generate secure random values | |
| state, _ := utils.SecureRandom(32) | |
| codeVerifier, _ := utils.SecureRandom(64) | |
| codeChallenge := crypto.SHA256Base64URL(codeVerifier) | |
| // Set cookies | |
| http.SetCookie(w, &http.Cookie{ | |
| Name: "oauth_state", | |
| Value: state, | |
| HttpOnly: true, | |
| Secure: true, | |
| SameSite: http.SameSiteNoneMode, | |
| Path: "/", | |
| MaxAge: 300, | |
| }) | |
| log.Println("Cookies set 1") | |
| http.SetCookie(w, &http.Cookie{ | |
| Name: "pkce_verifier", | |
| Value: codeVerifier, | |
| HttpOnly: true, | |
| Secure: true, | |
| SameSite: http.SameSiteNoneMode, | |
| Path: "/", | |
| MaxAge: 300, | |
| }) | |
| log.Println("Cookies set 2") | |
| // Construct the authorization URL for Keycloak | |
| authURL := fmt.Sprintf("%s/realms/%s/protocol/openid-connect/auth?client_id=%s&response_type=code&scope=openid profile email&redirect_uri=%s&state=%s&code_challenge=%s&code_challenge_method=S256", | |
| os.Getenv("KEYCLOAK_URL"), | |
| os.Getenv("KEYCLOAK_REALM"), | |
| os.Getenv("KEYCLOAK_CLIENT_ID"), | |
| url.QueryEscape(os.Getenv("KEYCLOAK_REDIRECT_URL")), | |
| state, | |
| codeChallenge, | |
| ) | |
| // Log the URL being redirected to | |
| log.Printf("Redirecting to URL: %s", authURL) | |
| // Perform the actual redirect | |
| http.Redirect(w, r, authURL, http.StatusTemporaryRedirect) | |
| } | |
| func Callback(w http.ResponseWriter, r *http.Request) { | |
| log.Println("Callback handler triggered") | |
| ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) | |
| defer cancel() | |
| code := r.URL.Query().Get("code") | |
| state := r.URL.Query().Get("state") | |
| stateCookie, err := r.Cookie("oauth_state") | |
| if err != nil || state != stateCookie.Value { | |
| http.Error(w, "Invalid state", http.StatusBadRequest) | |
| return | |
| } | |
| pkceCookie, err := r.Cookie("pkce_verifier") | |
| if err != nil { | |
| http.Error(w, "Missing PKCE verifier", http.StatusBadRequest) | |
| return | |
| } | |
| tokenResp, err := exchangeCodeForToken(code, pkceCookie.Value) | |
| if err != nil { | |
| http.Error(w, err.Error(), http.StatusInternalServerError) | |
| return | |
| } | |
| claims, err := verifyIDToken(tokenResp.IDToken) | |
| if err != nil { | |
| http.Error(w, fmt.Sprintf("Invalid ID token: %v", err), http.StatusUnauthorized) | |
| return | |
| } | |
| userID := claims["sub"].(string) | |
| email := claims["email"].(string) | |
| username := claims["preferred_username"].(string) | |
| sid, _ := claims["sid"].(string) | |
| sessionID, _ := utils.SecureRandom(32) | |
| csrfToken, _ := utils.SecureRandom(32) | |
| encAccess, _ := crypto.Encrypt(tokenResp.AccessToken) | |
| encRefresh, _ := crypto.Encrypt(tokenResp.RefreshToken) | |
| expiresAt := time.Now().Add(time.Hour * 24) | |
| accessExpires := time.Now().Add(time.Second * time.Duration(tokenResp.ExpiresIn)) | |
| session := models.Session{ | |
| SessionID: sessionID, | |
| UserID: userID, | |
| Email: email, | |
| Username: username, | |
| KeycloakSID: sid, | |
| EncryptedAccess: encAccess, | |
| EncryptedRefresh: encRefresh, | |
| AccessExpiresAt: accessExpires, | |
| ExpiresAt: expiresAt, | |
| } | |
| if err := repository.Save(ctx, session); err != nil { | |
| http.Error(w, fmt.Sprintf("Failed to save session: %v", err), http.StatusInternalServerError) | |
| return | |
| } | |
| http.SetCookie(w, &http.Cookie{ | |
| Name: "session_id", | |
| Value: sessionID, | |
| Path: "/", | |
| HttpOnly: true, | |
| Secure: true, | |
| SameSite: http.SameSiteLaxMode, | |
| Domain: os.Getenv("COOKIE_DOMAIN"), | |
| }) | |
| http.SetCookie(w, &http.Cookie{ | |
| Name: "XSRF-TOKEN", | |
| Value: csrfToken, | |
| Path: "/", | |
| HttpOnly: false, | |
| Secure: true, | |
| SameSite: http.SameSiteLaxMode, | |
| Domain: os.Getenv("COOKIE_DOMAIN"), | |
| }) | |
| clearCookie(w, "oauth_state") | |
| clearCookie(w, "pkce_verifier") | |
| http.Redirect(w, r, os.Getenv("FRONTEND_ORIGIN"), http.StatusFound) | |
| } | |
| func Me(w http.ResponseWriter, r *http.Request) { | |
| user, err := service.GetCurrentUser(r.Context()) | |
| if err != nil { | |
| http.Error(w, "Unauthorized", http.StatusUnauthorized) | |
| return | |
| } | |
| w.Header().Set("Content-Type", "application/json") | |
| json.NewEncoder(w).Encode(user) | |
| } | |
| func Logout(w http.ResponseWriter, r *http.Request) { | |
| // 1. Delete local session | |
| cookie, err := r.Cookie("session_id") | |
| if err == nil { | |
| _ = repository.DeleteSession(r.Context(), cookie.Value) | |
| } | |
| clearCookie := func(name string, httpOnly bool) { | |
| http.SetCookie(w, &http.Cookie{ | |
| Name: name, | |
| Value: "", | |
| Path: "/", | |
| HttpOnly: httpOnly, | |
| Secure: true, // only HTTPS | |
| MaxAge: -1, | |
| SameSite: http.SameSiteLaxMode, | |
| Domain: os.Getenv("COOKIE_DOMAIN"), | |
| }) | |
| } | |
| clearCookie("session_id", true) | |
| clearCookie("XSRF-TOKEN", false) | |
| logoutURL := fmt.Sprintf("%s/realms/%s/protocol/openid-connect/logout?redirect_uri=%s", | |
| os.Getenv("KEYCLOAK_URL"), | |
| os.Getenv("KEYCLOAK_REALM"), | |
| url.QueryEscape(os.Getenv("FRONTEND_ORIGIN")), | |
| ) | |
| http.Redirect(w, r, logoutURL, http.StatusFound) | |
| } | |
// clearCookie tells the browser to delete the cookie called name. It sends an
// already-expired replacement (Max-Age=0, Path "/", HttpOnly, Secure,
// SameSite=None).
//
// A browser only replaces a cookie whose name, path and domain all match.
// Cookies set without a Domain attribute are host-only, which includes the
// oauth_state and pkce_verifier cookies set by Login. A Set-Cookie that
// carries Domain=COOKIE_DOMAIN does not match them.
//
// So when COOKIE_DOMAIN is set, clearCookie expires both the domain-scoped
// and the host-only variant. It therefore works for either kind of cookie.
// Expiring a variant that does not exist is harmless.
func clearCookie(w http.ResponseWriter, name string) {
	expire := func(domain string) {
		http.SetCookie(w, &http.Cookie{
			Name:     name,
			Value:    "",
			MaxAge:   -1,
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteNoneMode,
			Domain:   domain,
		})
	}

	domain := os.Getenv("COOKIE_DOMAIN")
	expire(domain)
	if domain != "" {
		expire("")
	}
}
// tokenResponse holds the fields of the Keycloak token-endpoint JSON response
// that this package reads. Other members of the response are ignored on
// decode.
type tokenResponse struct {
	// AccessToken is the OAuth 2.0 access token. Callback stores it encrypted
	// in the session.
	AccessToken string `json:"access_token"`

	// RefreshToken is stored encrypted alongside the access token.
	RefreshToken string `json:"refresh_token"`

	// ExpiresIn is the access-token lifetime in seconds.
	ExpiresIn int `json:"expires_in"`

	// IDToken is the OpenID Connect ID token that verifyIDToken checks.
	IDToken string `json:"id_token"`
}
| func exchangeCodeForToken(code, verifier string) (*tokenResponse, error) { | |
| tokenURL := fmt.Sprintf("%s/realms/%s/protocol/openid-connect/token", | |
| os.Getenv("KEYCLOAK_URL"), | |
| os.Getenv("KEYCLOAK_REALM"), | |
| ) | |
| data := url.Values{} | |
| data.Set("grant_type", "authorization_code") | |
| data.Set("client_id", os.Getenv("KEYCLOAK_CLIENT_ID")) | |
| data.Set("client_secret", os.Getenv("KEYCLOAK_CLIENT_SECRET")) | |
| data.Set("code", code) | |
| data.Set("redirect_uri", os.Getenv("KEYCLOAK_REDIRECT_URL")) | |
| data.Set("code_verifier", verifier) | |
| resp, err := http.PostForm(tokenURL, data) | |
| if err != nil { | |
| return nil, err | |
| } | |
| defer resp.Body.Close() | |
| body, _ := io.ReadAll(resp.Body) | |
| if resp.StatusCode != http.StatusOK { | |
| return nil, errors.New("token exchange failed: " + string(body)) | |
| } | |
| var tr tokenResponse | |
| if err := json.Unmarshal(body, &tr); err != nil { | |
| return nil, err | |
| } | |
| return &tr, nil | |
| } | |
| func verifyIDToken(idToken string) (map[string]interface{}, error) { | |
| claims, err := service.VerifyIDTokenJWKS(idToken) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return claims, nil | |
| } | |
| func BackchannelLogout(w http.ResponseWriter, r *http.Request) { | |
| ctx := r.Context() | |
| if err := r.ParseForm(); err != nil { | |
| http.Error(w, "Invalid form", http.StatusBadRequest) | |
| return | |
| } | |
| logoutToken := r.FormValue("logout_token") | |
| if logoutToken == "" { | |
| http.Error(w, "Missing logout_token", http.StatusBadRequest) | |
| return | |
| } | |
| claims, err := service.VerifyIDTokenJWKS(logoutToken) | |
| if err != nil { | |
| log.Println("Invalid logout token:", err) | |
| http.Error(w, "Invalid token", http.StatusUnauthorized) | |
| return | |
| } | |
| expectedIssuer := fmt.Sprintf("%s/realms/%s", | |
| os.Getenv("KEYCLOAK_URL"), | |
| os.Getenv("KEYCLOAK_REALM"), | |
| ) | |
| if claims["iss"] != expectedIssuer { | |
| http.Error(w, "Invalid issuer", http.StatusUnauthorized) | |
| return | |
| } | |
| if claims["aud"] != os.Getenv("KEYCLOAK_CLIENT_ID") { | |
| http.Error(w, "Invalid audience", http.StatusUnauthorized) | |
| return | |
| } | |
| events, ok := claims["events"].(map[string]interface{}) | |
| if !ok { | |
| http.Error(w, "Invalid logout token structure", http.StatusUnauthorized) | |
| return | |
| } | |
| if _, ok := events["http://schemas.openid.net/event/backchannel-logout"]; !ok { | |
| http.Error(w, "Not a backchannel logout event", http.StatusUnauthorized) | |
| return | |
| } | |
| sid, ok := claims["sid"].(string) | |
| if !ok || sid == "" { | |
| http.Error(w, "Missing sid", http.StatusBadRequest) | |
| return | |
| } | |
| if err := repository.DeleteByKeycloakSID(ctx, sid); err != nil { | |
| if !errors.Is(err, mongo.ErrNoDocuments) { | |
| log.Println("DB delete error:", err) | |
| http.Error(w, "Internal error", http.StatusInternalServerError) | |
| return | |
| } | |
| } | |
| w.WriteHeader(http.StatusOK) | |
| } | |