# aditiaiblog's picture
# Add static file serving for React build with /proxy endpoint preservation
# 6897ce3 verified
import os
import json
import logging
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, validator
import httpx
import gradio as gr
from datetime import datetime, timedelta
from collections import defaultdict
import hashlib
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Secure Gemini API Proxy", version="1.0.0")

# CORS configuration.
# Exact origins can be supplied through the ALLOWED_ORIGINS environment
# variable as a comma-separated list (e.g. "https://a.example,https://b.example").
# When it is unset or blank, every origin is allowed, as before.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    # A wildcard origin must not be combined with credentials: the CORS
    # rules forbid it, and it would let any site issue credentialed
    # requests. Credentials are therefore enabled only for an explicit list.
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["POST", "GET"],  # Added GET for static files
    allow_headers=["*"],
)

# Rate limiting storage (in production, use Redis).
# Maps a client identifier to the list of datetimes of its recent requests.
rate_limit_storage = defaultdict(list)
# Request models
class GeminiRequest(BaseModel):
    """Request body accepted by the /proxy endpoint.

    Fields:
        prompt: Text to send to the model. Must contain non-whitespace
            characters and be at most 10000 characters long; it is stored
            with surrounding whitespace removed.
        temperature: Optional value in the inclusive range 0..1
            (default 0.7).
        max_tokens: Optional value in the inclusive range 1..4000
            (default 1000).
    """

    prompt: str
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1000

    @validator('prompt')
    def validate_prompt(cls, v):
        """Reject blank or oversized prompts; return the stripped prompt."""
        trimmed = v.strip() if v else ''
        if not trimmed:
            raise ValueError('Prompt cannot be empty')
        # The length cap is applied to the raw value, before stripping.
        if len(v) > 10000:  # Reasonable limit
            raise ValueError('Prompt too long')
        return trimmed

    @validator('temperature')
    def validate_temperature(cls, v):
        """Accept None, otherwise require 0 <= temperature <= 1."""
        if v is None:
            return v
        if v < 0 or v > 1:
            raise ValueError('Temperature must be between 0 and 1')
        return v

    @validator('max_tokens')
    def validate_max_tokens(cls, v):
        """Accept None, otherwise require 1 <= max_tokens <= 4000."""
        if v is None:
            return v
        if v < 1 or v > 4000:
            raise ValueError('Max tokens must be between 1 and 4000')
        return v
class ProxyResponse(BaseModel):
    """Response body declared as the response model of the /proxy endpoint."""

    # Text produced for the caller.
    response: str
    # Outcome marker for the request.
    status: str
    # Time the response was built, as a string.
    timestamp: str
# Security functions
def get_client_ip(request: Request) -> str:
    """Extract the client IP, honouring the X-Forwarded-For proxy header.

    The first non-empty entry of ``X-Forwarded-For`` is returned. When the
    header is absent or contains only blank entries, the address of the
    connected peer is used instead, and ``"unknown"`` when no peer
    information is available.

    NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
    proxy overwrites it, so the returned value can be spoofed when the app is
    reachable directly.

    Args:
        request: The incoming request.

    Returns:
        A client identifier string; never an empty string taken from the
        header.
    """
    forwarded_for = request.headers.get("X-Forwarded-For") or ""
    for entry in forwarded_for.split(","):
        candidate = entry.strip()
        # Skip blank entries such as the leading one in " , 1.2.3.4" so an
        # empty string is never used as the client identifier.
        if candidate:
            return candidate
    return request.client.host if request.client else "unknown"
def check_rate_limit(client_ip: str, max_requests: int = 30, window_minutes: int = 1) -> bool:
    """Sliding-window rate limiter backed by ``rate_limit_storage``.

    Timestamps older than ``window_minutes`` are discarded for this client,
    then the request is admitted (and recorded) only if fewer than
    ``max_requests`` timestamps remain.

    Args:
        client_ip: Key under which this client's request times are stored.
        max_requests: Maximum number of requests allowed inside the window.
        window_minutes: Length of the window in minutes.

    Returns:
        True when the request is allowed, False when the limit is reached.
    """
    current_time = datetime.now()
    cutoff = current_time - timedelta(minutes=window_minutes)

    # Keep only the timestamps that still fall inside the window.
    recent = list(filter(lambda stamp: stamp > cutoff, rate_limit_storage[client_ip]))
    rate_limit_storage[client_ip] = recent

    if len(recent) < max_requests:
        # Under the limit: record this request and admit it.
        recent.append(current_time)
        return True
    return False
def sanitize_input(text: str) -> str:
    """Return *text* with every ``<``, ``>``, ``&``, ``"`` and ``'`` removed."""
    # A translation table that maps each unwanted character to "deleted".
    removal_table = str.maketrans('', '', '<>&"\'')
    return text.translate(removal_table)
# Get Gemini API key from environment
def get_gemini_api_key() -> str:
    """Return the value of the ``GEMINI_API_KEY`` environment variable.

    Raises:
        HTTPException: Status 500 when the variable is unset or empty; the
            problem is also logged at error level.
    """
    configured_key = os.environ.get('GEMINI_API_KEY')
    if configured_key:
        return configured_key

    logger.error("GEMINI_API_KEY environment variable not found")
    raise HTTPException(status_code=500, detail="API configuration error")
def _extract_generated_text(gemini_response: Dict[str, Any]) -> str:
    """Return the first candidate's text, or a fallback when it is missing."""
    try:
        text = gemini_response['candidates'][0]['content']['parts'][0]['text']
    except (KeyError, IndexError, TypeError):
        # No candidates, or a candidate without content/parts/text.
        return "No response generated"
    return text if isinstance(text, str) else "No response generated"


@app.post("/proxy", response_model=ProxyResponse)
async def proxy_gemini(request_data: GeminiRequest, request: Request):
    """Secure proxy endpoint for Gemini API calls.

    Applies per-client rate limiting, strips unwanted characters from the
    prompt, forwards it to the Gemini ``generateContent`` endpoint and returns
    the generated text.

    Args:
        request_data: Validated prompt and generation settings.
        request: The incoming request, used to identify the client.

    Returns:
        A ProxyResponse holding the generated text, the status "success" and
        an ISO-format timestamp.

    Raises:
        HTTPException: 429 when the client exceeded the rate limit, 500 when
            the API key is not configured or an unexpected error occurs, and
            502 when the Gemini API answers with a non-200 status.
    """
    try:
        # Get client IP and check rate limiting
        client_ip = get_client_ip(request)
        if not check_rate_limit(client_ip):
            logger.warning(f"Rate limit exceeded for IP: {client_ip}")
            raise HTTPException(status_code=429, detail="Rate limit exceeded")

        # Sanitize input
        sanitized_prompt = sanitize_input(request_data.prompt)

        # Get API key
        api_key = get_gemini_api_key()

        # Prepare request to Gemini API. The key is sent in a header rather
        # than in the query string so it cannot end up in request logs or in
        # error messages that include the URL.
        gemini_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
        payload = {
            "contents": [{
                "parts": [{
                    "text": sanitized_prompt
                }]
            }],
            "generationConfig": {
                "temperature": request_data.temperature,
                "maxOutputTokens": request_data.max_tokens,
                "topP": 0.8,
                "topK": 40
            }
        }

        # Make request to Gemini API
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                gemini_url,
                json=payload,
                headers={
                    "Content-Type": "application/json",
                    "x-goog-api-key": api_key,
                },
            )

        if response.status_code != 200:
            logger.error(f"Gemini API error: {response.status_code} - {response.text}")
            raise HTTPException(status_code=502, detail="External API error")

        # Parse response and extract generated text. A candidate without
        # content/parts yields the fallback text instead of an error.
        generated_text = _extract_generated_text(response.json())

        # Log successful request (without sensitive data)
        logger.info(f"Successful proxy request from IP: {client_ip}")

        return ProxyResponse(
            response=generated_text,
            status="success",
            timestamp=datetime.now().isoformat()
        )
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error in proxy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
@app.get("/health")
async def health_check():
    """Report that the service is up, together with the current server time."""
    current_time = datetime.now().isoformat()
    return {"status": "healthy", "timestamp": current_time}
# Mount static files from the build directory
# This serves the React build files at the root path
if os.path.exists("./build"):
    # Resolved once so that requested paths can be confined to this directory.
    _BUILD_ROOT = os.path.realpath("./build")
    _INDEX_HTML = os.path.join(_BUILD_ROOT, "index.html")
    _STATIC_DIR = os.path.join(_BUILD_ROOT, "static")

    # Only mount /static when the directory is present, so a build without a
    # static/ folder does not prevent the application from starting.
    if os.path.isdir(_STATIC_DIR):
        app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")

    def _resolve_build_file(relative_path: str) -> Optional[str]:
        """Return the real path of a file inside the build directory, else None.

        Paths that escape the build directory (``..`` segments, absolute
        paths, symlinks pointing outside) are rejected.
        """
        try:
            candidate = os.path.realpath(os.path.join(_BUILD_ROOT, relative_path))
            contained = os.path.commonpath([_BUILD_ROOT, candidate]) == _BUILD_ROOT
        except ValueError:
            # Malformed input (e.g. embedded NUL) or paths on different drives.
            return None
        if contained and os.path.isfile(candidate):
            return candidate
        return None

    @app.get("/")
    async def serve_react_app():
        """Serve the main React app"""
        return FileResponse(_INDEX_HTML)

    @app.get("/{full_path:path}")
    async def serve_react_routes(full_path: str):
        """Serve React routes (for client-side routing).

        Returns the matching file from the build directory when one exists,
        otherwise the React entry page so the client-side router can handle
        the path.

        Raises:
            HTTPException: 404 for paths starting with "proxy", "health" or
                "gradio".
        """
        # Skip API routes
        # NOTE(review): this catch-all route is registered before the Gradio
        # app is mounted further down in the file, so GET requests under
        # /gradio look like they are answered here with 404 - confirm.
        if full_path.startswith(("proxy", "health", "gradio")):
            raise HTTPException(status_code=404, detail="Not found")

        # Check if it's a static file (confined to the build directory)
        file_path = _resolve_build_file(full_path)
        if file_path is not None:
            return FileResponse(file_path)

        # For all other routes, serve the React app (client-side routing)
        return FileResponse(_INDEX_HTML)
else:
    # Fallback to Gradio if build directory doesn't exist
    logger.warning("Build directory not found, falling back to Gradio interface")
# Gradio interface
def gradio_interface(prompt: str, temperature: float = 0.7, max_tokens: int = 1000):
    """Demo handler for the Gradio UI.

    Returns a placeholder string containing at most the first 100 characters
    of *prompt*; ``temperature`` and ``max_tokens`` are accepted but not
    used. Any exception is reported as an "Error: ..." string instead of
    being raised.
    """
    try:
        # Placeholder only: the real proxy endpoint is not called from here.
        preview = prompt[:100]
        return f"[Demo Mode] Processed prompt: {preview}..."
    except Exception as exc:
        return "Error: " + str(exc)
# Create Gradio interface
# Input widgets, created in the order in which they are shown.
_prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3)
_temperature_slider = gr.Slider(0, 1, value=0.7, label="Temperature")
_max_tokens_slider = gr.Slider(1, 4000, value=1000, step=1, label="Max Tokens")

# Output widget.
_response_box = gr.Textbox(label="Response", lines=5)

# Gradio UI wired to the demo handler.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[_prompt_box, _temperature_slider, _max_tokens_slider],
    outputs=_response_box,
    title="Secure Gemini API Proxy",
    description="A secure proxy interface for Google's Gemini API with rate limiting and input validation."
)

# Mount Gradio app at /gradio path to keep it accessible
app = gr.mount_gradio_app(app, iface, path="/gradio")
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    # uvicorn is imported here so it is only needed when run as a script.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)