// File: internal/auth/handler/auth_handler.go
// Last change: "Update internal/auth/handler/auth_handler.go" by learnifymedhub (commit 6301200, verified).

// internal/auth/handler/auth_handler.go
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
"log"
"server/internal/auth/models"
"server/internal/auth/repository"
"server/internal/auth/service"
"server/internal/shared/crypto"
"server/internal/shared/utils"
"go.mongodb.org/mongo-driver/mongo"
)
// LoggingResponseWriter wraps an http.ResponseWriter and records the status
// code sent to the client, so middleware can inspect it after the handler
// returns.
type LoggingResponseWriter struct {
	http.ResponseWriter
	// StatusCode is the last status passed to WriteHeader, 200 if the body was
	// written without an explicit WriteHeader, or 0 if nothing was written
	// (unless the caller pre-initialised it).
	StatusCode int
}

// WriteHeader records statusCode and forwards it to the wrapped writer.
func (rw *LoggingResponseWriter) WriteHeader(statusCode int) {
	rw.StatusCode = statusCode
	rw.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards b to the wrapped writer. net/http sends an implicit
// 200 OK when a handler writes a body without calling WriteHeader; Write
// records that so StatusCode does not stay 0 for ordinary successful responses.
func (rw *LoggingResponseWriter) Write(b []byte) (int, error) {
	if rw.StatusCode == 0 {
		rw.StatusCode = http.StatusOK
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional interfaces (Flush, deadlines, ...) through this wrapper.
func (rw *LoggingResponseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// Login starts the OpenID Connect authorization-code flow with PKCE.
//
// It generates a random state value and a PKCE code verifier. Both are stored
// in 5-minute HttpOnly cookies (oauth_state, pkce_verifier) so Callback can
// check them. It then redirects the browser (307) to the Keycloak
// authorization endpoint built from KEYCLOAK_URL, KEYCLOAK_REALM,
// KEYCLOAK_CLIENT_ID and KEYCLOAK_REDIRECT_URL.
//
// If random generation fails the request is answered with 500. Continuing with
// an empty state or verifier would silently disable the CSRF and PKCE
// protections.
func Login(w http.ResponseWriter, r *http.Request) {
	log.Println("Login handler triggered")

	state, err := utils.SecureRandom(32)
	if err != nil {
		log.Printf("login: generating state: %v", err)
		http.Error(w, "Internal error", http.StatusInternalServerError)
		return
	}
	codeVerifier, err := utils.SecureRandom(64)
	if err != nil {
		log.Printf("login: generating PKCE verifier: %v", err)
		http.Error(w, "Internal error", http.StatusInternalServerError)
		return
	}
	codeChallenge := crypto.SHA256Base64URL(codeVerifier)

	// Both one-shot cookies share the same attributes; they are host-only
	// (no Domain) and expire after 300 seconds.
	// NOTE(review): SameSite=None is kept as-is; Lax would also cover a
	// top-level GET redirect back to the callback — confirm before tightening.
	for _, c := range []struct{ name, value string }{
		{"oauth_state", state},
		{"pkce_verifier", codeVerifier},
	} {
		http.SetCookie(w, &http.Cookie{
			Name:     c.name,
			Value:    c.value,
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteNoneMode,
			Path:     "/",
			MaxAge:   300,
		})
	}

	// url.Values percent-encodes every parameter, including the
	// space-separated scope list, which was previously sent raw.
	query := url.Values{}
	query.Set("client_id", os.Getenv("KEYCLOAK_CLIENT_ID"))
	query.Set("response_type", "code")
	query.Set("scope", "openid profile email")
	query.Set("redirect_uri", os.Getenv("KEYCLOAK_REDIRECT_URL"))
	query.Set("state", state)
	query.Set("code_challenge", codeChallenge)
	query.Set("code_challenge_method", "S256")

	authURL := fmt.Sprintf("%s/realms/%s/protocol/openid-connect/auth?%s",
		os.Getenv("KEYCLOAK_URL"),
		os.Getenv("KEYCLOAK_REALM"),
		query.Encode(),
	)

	log.Printf("Redirecting to URL: %s", authURL)
	http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
}
func Callback(w http.ResponseWriter, r *http.Request) {
log.Println("Callback handler triggered")
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
stateCookie, err := r.Cookie("oauth_state")
if err != nil || state != stateCookie.Value {
http.Error(w, "Invalid state", http.StatusBadRequest)
return
}
pkceCookie, err := r.Cookie("pkce_verifier")
if err != nil {
http.Error(w, "Missing PKCE verifier", http.StatusBadRequest)
return
}
tokenResp, err := exchangeCodeForToken(code, pkceCookie.Value)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
claims, err := verifyIDToken(tokenResp.IDToken)
if err != nil {
http.Error(w, fmt.Sprintf("Invalid ID token: %v", err), http.StatusUnauthorized)
return
}
userID := claims["sub"].(string)
email := claims["email"].(string)
username := claims["preferred_username"].(string)
sid, _ := claims["sid"].(string)
sessionID, _ := utils.SecureRandom(32)
csrfToken, _ := utils.SecureRandom(32)
encAccess, _ := crypto.Encrypt(tokenResp.AccessToken)
encRefresh, _ := crypto.Encrypt(tokenResp.RefreshToken)
expiresAt := time.Now().Add(time.Hour * 24)
accessExpires := time.Now().Add(time.Second * time.Duration(tokenResp.ExpiresIn))
session := models.Session{
SessionID: sessionID,
UserID: userID,
Email: email,
Username: username,
KeycloakSID: sid,
EncryptedAccess: encAccess,
EncryptedRefresh: encRefresh,
AccessExpiresAt: accessExpires,
ExpiresAt: expiresAt,
}
if err := repository.Save(ctx, session); err != nil {
http.Error(w, fmt.Sprintf("Failed to save session: %v", err), http.StatusInternalServerError)
return
}
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
Domain: os.Getenv("COOKIE_DOMAIN"),
})
http.SetCookie(w, &http.Cookie{
Name: "XSRF-TOKEN",
Value: csrfToken,
Path: "/",
HttpOnly: false,
Secure: true,
SameSite: http.SameSiteLaxMode,
Domain: os.Getenv("COOKIE_DOMAIN"),
})
clearCookie(w, "oauth_state")
clearCookie(w, "pkce_verifier")
http.Redirect(w, r, os.Getenv("FRONTEND_ORIGIN"), http.StatusFound)
}
// Me writes the current user, as resolved by service.GetCurrentUser from the
// request context, as a JSON response. If no user can be resolved it answers
// 401 Unauthorized.
func Me(w http.ResponseWriter, r *http.Request) {
	user, err := service.GetCurrentUser(r.Context())
	if err != nil {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// The payload is per-user; keep it out of shared and browser caches.
	w.Header().Set("Cache-Control", "no-store")
	if err := json.NewEncoder(w).Encode(user); err != nil {
		// The status line may already be sent, so it cannot be changed here;
		// record the failure instead of dropping it silently.
		log.Printf("me: encoding response: %v", err)
	}
}
// Logout ends the user's session both locally and at Keycloak.
//
//  1. Delete the stored session named by the session_id cookie. This is best
//     effort: failures are logged, and the cookies are expired regardless.
//  2. Expire the session_id and XSRF-TOKEN cookies. The Path, Domain and
//     SameSite values match those Callback uses when it sets them.
//  3. Redirect (302) to the Keycloak end-session endpoint. Keycloak then
//     redirects back to FRONTEND_ORIGIN.
func Logout(w http.ResponseWriter, r *http.Request) {
	if cookie, err := r.Cookie("session_id"); err == nil {
		if err := repository.DeleteSession(r.Context(), cookie.Value); err != nil {
			log.Printf("logout: deleting session: %v", err)
		}
	}

	// Named expire (not clearCookie) so it does not shadow the package-level
	// clearCookie helper, which uses different attributes.
	cookieDomain := os.Getenv("COOKIE_DOMAIN")
	expire := func(name string, httpOnly bool) {
		http.SetCookie(w, &http.Cookie{
			Name:     name,
			Value:    "",
			Path:     "/",
			HttpOnly: httpOnly,
			Secure:   true, // only HTTPS
			MaxAge:   -1,
			SameSite: http.SameSiteLaxMode,
			Domain:   cookieDomain,
		})
	}
	expire("session_id", true)
	expire("XSRF-TOKEN", false)

	// Keycloak 18+ rejects the legacy redirect_uri logout parameter. It
	// expects post_logout_redirect_uri together with client_id (or
	// id_token_hint). FRONTEND_ORIGIN must be listed in the client's
	// "Valid post logout redirect URIs".
	query := url.Values{}
	query.Set("client_id", os.Getenv("KEYCLOAK_CLIENT_ID"))
	query.Set("post_logout_redirect_uri", os.Getenv("FRONTEND_ORIGIN"))
	logoutURL := fmt.Sprintf("%s/realms/%s/protocol/openid-connect/logout?%s",
		os.Getenv("KEYCLOAK_URL"),
		os.Getenv("KEYCLOAK_REALM"),
		query.Encode(),
	)
	http.Redirect(w, r, logoutURL, http.StatusFound)
}
// clearCookie expires the cookie called name. The deletion cookie uses
// Path "/", HttpOnly, Secure and SameSite=None.
//
// Browsers identify a cookie by name, domain and path. A host-only cookie
// (set without a Domain attribute, as Login sets oauth_state and
// pkce_verifier) is therefore distinct from a Domain=COOKIE_DOMAIN cookie of
// the same name. Expiring only the domain variant would leave the host-only
// cookie in place.
//
// clearCookie always expires the host-only variant. When COOKIE_DOMAIN is
// set, it also expires the domain variant.
func clearCookie(w http.ResponseWriter, name string) {
	domains := []string{""}
	if d := os.Getenv("COOKIE_DOMAIN"); d != "" {
		domains = append(domains, d)
	}
	for _, domain := range domains {
		http.SetCookie(w, &http.Cookie{
			Name:     name,
			Value:    "",
			MaxAge:   -1,
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteNoneMode,
			Domain:   domain,
		})
	}
}
// tokenResponse holds the fields of the OAuth 2.0 / OIDC token endpoint
// response that this package uses.
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	ExpiresIn    int    `json:"expires_in"` // access-token lifetime in seconds
	IDToken      string `json:"id_token"`
}

// maxTokenResponseBytes caps how much of the token endpoint's response body is
// read into memory.
const maxTokenResponseBytes = 1 << 20

// tokenExchangeClient is used for back-channel calls to the Keycloak token
// endpoint. Unlike http.DefaultClient it has a timeout, so a stalled Keycloak
// cannot block the calling handler indefinitely.
var tokenExchangeClient = &http.Client{Timeout: 10 * time.Second}

// exchangeCodeForToken redeems an authorization code at the Keycloak token
// endpoint (KEYCLOAK_URL/realms/KEYCLOAK_REALM/protocol/openid-connect/token).
// The request carries the PKCE verifier and authenticates the client with
// KEYCLOAK_CLIENT_ID and KEYCLOAK_CLIENT_SECRET.
//
// It returns an error in these cases:
//   - the transport fails;
//   - the response status is not 200 (the error includes the status and the
//     body, which is capped at maxTokenResponseBytes);
//   - the JSON cannot be decoded;
//   - the response lacks an access or ID token.
func exchangeCodeForToken(code, verifier string) (*tokenResponse, error) {
	tokenURL := fmt.Sprintf("%s/realms/%s/protocol/openid-connect/token",
		os.Getenv("KEYCLOAK_URL"),
		os.Getenv("KEYCLOAK_REALM"),
	)

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", os.Getenv("KEYCLOAK_CLIENT_ID"))
	form.Set("client_secret", os.Getenv("KEYCLOAK_CLIENT_SECRET"))
	form.Set("code", code)
	form.Set("redirect_uri", os.Getenv("KEYCLOAK_REDIRECT_URL"))
	form.Set("code_verifier", verifier)

	resp, err := tokenExchangeClient.PostForm(tokenURL, form)
	if err != nil {
		return nil, fmt.Errorf("token request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes))
	if err != nil {
		return nil, fmt.Errorf("reading token response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token exchange failed (HTTP %d): %s", resp.StatusCode, body)
	}

	var tr tokenResponse
	if err := json.Unmarshal(body, &tr); err != nil {
		return nil, fmt.Errorf("decoding token response: %w", err)
	}
	if tr.AccessToken == "" {
		return nil, errors.New("token response has no access_token")
	}
	if tr.IDToken == "" {
		return nil, errors.New("token response has no id_token")
	}
	return &tr, nil
}
// verifyIDToken validates an ID token obtained from the token endpoint and
// returns its claims.
//
// Signature verification is delegated to service.VerifyIDTokenJWKS. On top of
// that it enforces the issuer and audience checks required by OpenID Connect
// Core §3.1.3.7:
//   - iss must equal KEYCLOAK_URL/realms/KEYCLOAK_REALM (the same expectation
//     BackchannelLogout uses);
//   - aud, a string or an array, must include KEYCLOAK_CLIENT_ID.
//
// On any failure it returns nil claims.
// NOTE(review): if VerifyIDTokenJWKS already enforces iss/aud these checks are
// redundant but harmless — confirm.
func verifyIDToken(idToken string) (map[string]interface{}, error) {
	claims, err := service.VerifyIDTokenJWKS(idToken)
	if err != nil {
		return nil, err
	}

	expectedIssuer := fmt.Sprintf("%s/realms/%s",
		os.Getenv("KEYCLOAK_URL"),
		os.Getenv("KEYCLOAK_REALM"),
	)
	if iss, _ := claims["iss"].(string); iss != expectedIssuer {
		return nil, fmt.Errorf("unexpected issuer %q", iss)
	}

	clientID := os.Getenv("KEYCLOAK_CLIENT_ID")
	audOK := false
	switch aud := claims["aud"].(type) {
	case string:
		audOK = aud == clientID
	case []interface{}:
		for _, a := range aud {
			if s, ok := a.(string); ok && s == clientID {
				audOK = true
				break
			}
		}
	case []string:
		for _, s := range aud {
			if s == clientID {
				audOK = true
				break
			}
		}
	}
	if !audOK {
		return nil, errors.New("ID token audience does not include this client")
	}
	return claims, nil
}
// BackchannelLogout handles OpenID Connect Back-Channel Logout requests.
//
// The OP POSTs a form-encoded logout_token. The token's signature is verified
// with service.VerifyIDTokenJWKS. The handler then checks, per Back-Channel
// Logout 1.0 §2.6:
//   - iss equals KEYCLOAK_URL/realms/KEYCLOAK_REALM;
//   - aud (string or array) includes KEYCLOAK_CLIENT_ID;
//   - the events claim carries the back-channel-logout event;
//   - no nonce claim is present;
//   - a sid claim is present.
//
// If all checks pass, the sessions stored with that KeycloakSID are deleted
// (a missing session is not an error).
//
// Responses:
//   - invalid requests get 400 Bad Request, as §2.8 requires;
//   - storage failures get 500;
//   - success gets 200;
//   - every response carries Cache-Control: no-store.
func BackchannelLogout(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	// §2.8: logout responses must not be cached.
	w.Header().Set("Cache-Control", "no-store")

	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Invalid form", http.StatusBadRequest)
		return
	}
	logoutToken := r.PostFormValue("logout_token")
	if logoutToken == "" {
		http.Error(w, "Missing logout_token", http.StatusBadRequest)
		return
	}

	claims, err := service.VerifyIDTokenJWKS(logoutToken)
	if err != nil {
		log.Println("Invalid logout token:", err)
		http.Error(w, "Invalid token", http.StatusBadRequest)
		return
	}

	expectedIssuer := fmt.Sprintf("%s/realms/%s",
		os.Getenv("KEYCLOAK_URL"),
		os.Getenv("KEYCLOAK_REALM"),
	)
	if iss, _ := claims["iss"].(string); iss != expectedIssuer {
		http.Error(w, "Invalid issuer", http.StatusBadRequest)
		return
	}

	// aud may legally be a single string or an array of strings.
	clientID := os.Getenv("KEYCLOAK_CLIENT_ID")
	audOK := false
	switch aud := claims["aud"].(type) {
	case string:
		audOK = aud == clientID
	case []interface{}:
		for _, a := range aud {
			if s, ok := a.(string); ok && s == clientID {
				audOK = true
				break
			}
		}
	case []string:
		for _, s := range aud {
			if s == clientID {
				audOK = true
				break
			}
		}
	}
	if !audOK {
		http.Error(w, "Invalid audience", http.StatusBadRequest)
		return
	}

	events, ok := claims["events"].(map[string]interface{})
	if !ok {
		http.Error(w, "Invalid logout token structure", http.StatusBadRequest)
		return
	}
	if _, ok := events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
		http.Error(w, "Not a backchannel logout event", http.StatusBadRequest)
		return
	}
	// §2.6: a logout token must not contain a nonce claim.
	if _, hasNonce := claims["nonce"]; hasNonce {
		http.Error(w, "Invalid logout token: nonce present", http.StatusBadRequest)
		return
	}

	sid, ok := claims["sid"].(string)
	if !ok || sid == "" {
		http.Error(w, "Missing sid", http.StatusBadRequest)
		return
	}

	if err := repository.DeleteByKeycloakSID(ctx, sid); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
		log.Println("DB delete error:", err)
		http.Error(w, "Internal error", http.StatusInternalServerError)
		return
	}
	w.WriteHeader(http.StatusOK)
}