import os
import json
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, Any, Optional
import hashlib

import httpx
import gradio as gr
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, validator

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Secure Gemini API Proxy", version="1.0.0")

# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests and is unsafe in production —
# pin exact origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify exact origins
    allow_credentials=True,
    allow_methods=["POST", "GET"],  # Added GET for static files
    allow_headers=["*"],
)

# Rate limiting storage: client IP -> list of request timestamps.
# In-process only (in production, use Redis or similar).
rate_limit_storage = defaultdict(list)


# Request models
class GeminiRequest(BaseModel):
    """Validated payload accepted by the /proxy endpoint."""

    prompt: str
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1000

    @validator('prompt')
    def validate_prompt(cls, v):
        """Reject empty or oversized prompts; return the stripped text."""
        if not v or len(v.strip()) == 0:
            raise ValueError('Prompt cannot be empty')
        if len(v) > 10000:  # Reasonable limit
            raise ValueError('Prompt too long')
        return v.strip()

    @validator('temperature')
    def validate_temperature(cls, v):
        """Allow only temperatures in [0, 1] (None keeps the default)."""
        if v is not None and (v < 0 or v > 1):
            raise ValueError('Temperature must be between 0 and 1')
        return v

    @validator('max_tokens')
    def validate_max_tokens(cls, v):
        """Allow only token budgets in [1, 4000] (None keeps the default)."""
        if v is not None and (v < 1 or v > 4000):
            raise ValueError('Max tokens must be between 1 and 4000')
        return v


class ProxyResponse(BaseModel):
    """Shape of a successful /proxy response."""

    response: str
    status: str
    timestamp: str


# Security functions
def get_client_ip(request: Request) -> str:
    """Extract client IP with proxy header support.

    NOTE(review): X-Forwarded-For is client-controlled; only trust it when
    the app runs behind a proxy that overwrites the header.
    """
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # First entry is the originating client when set by a trusted proxy
        return forwarded_for.split(",")[0].strip()
    return request.client.host if request.client else "unknown"
def check_rate_limit(client_ip: str, max_requests: int = 30, window_minutes: int = 1) -> bool:
    """Simple sliding-window rate limiting check.

    Returns True and records the request when the caller is under the limit;
    returns False otherwise. Not thread-safe — acceptable for a single
    worker; use Redis (or similar) in production.
    """
    now = datetime.now()
    window_start = now - timedelta(minutes=window_minutes)

    # Drop requests that have fallen out of the window
    rate_limit_storage[client_ip] = [
        req_time for req_time in rate_limit_storage[client_ip]
        if req_time > window_start
    ]

    # Check if under limit
    if len(rate_limit_storage[client_ip]) >= max_requests:
        return False

    # Add current request
    rate_limit_storage[client_ip].append(now)
    return True


def sanitize_input(text: str) -> str:
    """Basic input sanitization: remove a few HTML/JS-meaningful characters.

    NOTE(review): this mutates legitimate prompts (quotes, '&', angle
    brackets) and is not a substitute for output encoding; the removed
    character set is kept identical to the original implementation.
    """
    # One C-level pass instead of chained .replace() calls
    return text.translate(str.maketrans('', '', '<>&"\''))


def get_gemini_api_key() -> str:
    """Retrieve Gemini API key from environment variables.

    Raises HTTPException(500) when GEMINI_API_KEY is unset so configuration
    details are never exposed to the client.
    """
    api_key = os.getenv('GEMINI_API_KEY')
    if not api_key:
        logger.error("GEMINI_API_KEY environment variable not found")
        raise HTTPException(status_code=500, detail="API configuration error")
    return api_key


@app.post("/proxy", response_model=ProxyResponse)
async def proxy_gemini(request_data: GeminiRequest, request: Request):
    """Secure proxy endpoint for Gemini API calls.

    Rate-limits per client IP, sanitizes the prompt, forwards the request to
    the Gemini generateContent API, and returns the generated text.
    Raises 429 on rate limit, 502 on upstream failure, 500 otherwise.
    """
    try:
        # Get client IP and check rate limiting
        client_ip = get_client_ip(request)
        if not check_rate_limit(client_ip):
            logger.warning(f"Rate limit exceeded for IP: {client_ip}")
            raise HTTPException(status_code=429, detail="Rate limit exceeded")

        # Sanitize input
        sanitized_prompt = sanitize_input(request_data.prompt)

        # Get API key
        api_key = get_gemini_api_key()

        # SECURITY: send the key in the x-goog-api-key header instead of the
        # query string so it cannot leak into access logs or proxies that
        # record request URLs.
        gemini_url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            "gemini-pro:generateContent"
        )
        payload = {
            "contents": [{
                "parts": [{
                    "text": sanitized_prompt
                }]
            }],
            "generationConfig": {
                "temperature": request_data.temperature,
                "maxOutputTokens": request_data.max_tokens,
                "topP": 0.8,
                "topK": 40,
            },
        }

        # Make request to Gemini API
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                gemini_url,
                json=payload,
                headers={
                    "Content-Type": "application/json",
                    "x-goog-api-key": api_key,
                },
            )

        if response.status_code != 200:
            logger.error(f"Gemini API error: {response.status_code} - {response.text}")
            raise HTTPException(status_code=502, detail="External API error")

        # Parse response
        gemini_response = response.json()

        # Extract generated text (candidates may be absent, e.g. when the
        # prompt was blocked upstream)
        if 'candidates' in gemini_response and len(gemini_response['candidates']) > 0:
            generated_text = gemini_response['candidates'][0]['content']['parts'][0]['text']
        else:
            generated_text = "No response generated"

        # Log successful request (without sensitive data)
        logger.info(f"Successful proxy request from IP: {client_ip}")

        return ProxyResponse(
            response=generated_text,
            status="success",
            timestamp=datetime.now().isoformat(),
        )

    except HTTPException:
        # Re-raise deliberate HTTP errors untouched
        raise
    except Exception as e:
        logger.error(f"Unexpected error in proxy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "timestamp": datetime.now().isoformat()}


# Mount static files from the build directory
# This serves the React build files at the root path
if os.path.exists("./build"):
    app.mount("/static", StaticFiles(directory="./build/static"), name="static")

    @app.get("/")
    async def serve_react_app():
        """Serve the main React app"""
        return FileResponse("./build/index.html")

    @app.get("/{full_path:path}")
    async def serve_react_routes(full_path: str):
        """Serve React routes (for client-side routing)"""
        # Skip API routes
        if full_path.startswith("proxy") or full_path.startswith("health") or full_path.startswith("gradio"):
            raise HTTPException(status_code=404, detail="Not found")

        # SECURITY: resolve the requested path and refuse anything that
        # escapes the build directory (e.g. "../" traversal in full_path).
        build_root = os.path.realpath("./build")
        file_path = os.path.realpath(os.path.join(build_root, full_path))
        if file_path.startswith(build_root + os.sep) and os.path.isfile(file_path):
            return FileResponse(file_path)

        # For all other routes, serve the React app (client-side routing)
        return FileResponse("./build/index.html")
else:
    # Fallback to Gradio if build directory doesn't exist
    logger.warning("Build directory not found, falling back to Gradio interface")

    def gradio_interface(prompt: str, temperature: float = 0.7, max_tokens: int = 1000):
        """Gradio wrapper for the proxy function"""
        try:
            # This would typically call the proxy endpoint internally
            # For demo purposes, we'll return a placeholder
            return f"[Demo Mode] Processed prompt: {prompt[:100]}..."
        except Exception as e:
            return f"Error: {str(e)}"

    # Create Gradio interface
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3),
            gr.Slider(0, 1, value=0.7, label="Temperature"),
            gr.Slider(1, 4000, value=1000, step=1, label="Max Tokens"),
        ],
        outputs=gr.Textbox(label="Response", lines=5),
        title="Secure Gemini API Proxy",
        description="A secure proxy interface for Google's Gemini API with rate limiting and input validation.",
    )

    # Mount Gradio app at /gradio path to keep it accessible
    app = gr.mount_gradio_app(app, iface, path="/gradio")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)