Spaces:
Sleeping
Sleeping
File size: 8,540 Bytes
f112ff8 6897ce3 f112ff8 6897ce3 f112ff8 6897ce3 f112ff8 6897ce3 f112ff8 6897ce3 f112ff8 6897ce3 f112ff8 6897ce3 f112ff8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 | import os
import json
import logging
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, validator
import httpx
import gradio as gr
from datetime import datetime, timedelta
from collections import defaultdict
import hashlib
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Secure Gemini API Proxy", version="1.0.0")

# Add CORS middleware with restricted origins
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# contradictory — browsers reject wildcard origins on credentialed requests,
# so pin exact origins in production as the inline comment already says.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify exact origins
    allow_credentials=True,
    allow_methods=["POST", "GET"],  # Added GET for static files
    allow_headers=["*"],
)

# Rate limiting storage (in production, use Redis)
# Maps client IP -> list of request timestamps inside the current window.
# NOTE(review): keys for idle IPs are never purged, so memory grows with the
# number of distinct clients — Redis with TTLs would fix this.
rate_limit_storage = defaultdict(list)
# Request models
class GeminiRequest(BaseModel):
prompt: str
temperature: Optional[float] = 0.7
max_tokens: Optional[int] = 1000
@validator('prompt')
def validate_prompt(cls, v):
if not v or len(v.strip()) == 0:
raise ValueError('Prompt cannot be empty')
if len(v) > 10000: # Reasonable limit
raise ValueError('Prompt too long')
return v.strip()
@validator('temperature')
def validate_temperature(cls, v):
if v is not None and (v < 0 or v > 1):
raise ValueError('Temperature must be between 0 and 1')
return v
@validator('max_tokens')
def validate_max_tokens(cls, v):
if v is not None and (v < 1 or v > 4000):
raise ValueError('Max tokens must be between 1 and 4000')
return v
class ProxyResponse(BaseModel):
    """Response envelope returned by the /proxy endpoint."""
    response: str  # generated text, or "No response generated" as a fallback
    status: str  # "success" on the happy path
    timestamp: str  # ISO-8601 time the response was produced
# Security functions
def get_client_ip(request: Request) -> str:
    """Return the caller's IP address, honouring the X-Forwarded-For header.

    NOTE(review): X-Forwarded-For is client-controlled and spoofable unless a
    trusted reverse proxy rewrites it — confirm the deployment sits behind one.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # First entry in the comma-separated chain is the originating client.
        return forwarded.split(",")[0].strip()
    client = request.client
    return client.host if client else "unknown"
def check_rate_limit(client_ip: str, max_requests: int = 30, window_minutes: int = 1) -> bool:
    """Sliding-window rate limit: allow up to *max_requests* per window per IP.

    Records the current request when allowed; returns False once the window
    already holds max_requests timestamps.
    """
    now = datetime.now()
    cutoff = now - timedelta(minutes=window_minutes)
    # Keep only timestamps still inside the window for this client.
    recent = [stamp for stamp in rate_limit_storage[client_ip] if stamp > cutoff]
    rate_limit_storage[client_ip] = recent
    if len(recent) >= max_requests:
        return False
    # Under the limit: record this request (list object is shared with storage).
    recent.append(now)
    return True
# Translation table dropping HTML/JS-sensitive characters, built once at
# import time instead of once per call.
_SANITIZE_TABLE = str.maketrans('', '', '<>&"\'')

def sanitize_input(text: str) -> str:
    """Remove potentially harmful characters (< > & " ') from *text*.

    Single C-level str.translate pass instead of five chained .replace()
    scans — identical output, one traversal of the string.

    NOTE(review): stripping (rather than escaping) silently mangles
    legitimate prompts that contain quotes or ampersands — consider
    HTML-escaping at the output boundary instead.
    """
    return text.translate(_SANITIZE_TABLE)
# Get Gemini API key from environment
def get_gemini_api_key() -> str:
    """Fetch the Gemini API key from the GEMINI_API_KEY environment variable.

    Raises HTTPException(500) with a generic message when the variable is
    missing, so configuration details never leak to the client.
    """
    key = os.getenv('GEMINI_API_KEY')
    if key:
        return key
    logger.error("GEMINI_API_KEY environment variable not found")
    raise HTTPException(status_code=500, detail="API configuration error")
@app.post("/proxy", response_model=ProxyResponse)
async def proxy_gemini(request_data: GeminiRequest, request: Request):
"""Secure proxy endpoint for Gemini API calls"""
try:
# Get client IP and check rate limiting
client_ip = get_client_ip(request)
if not check_rate_limit(client_ip):
logger.warning(f"Rate limit exceeded for IP: {client_ip}")
raise HTTPException(status_code=429, detail="Rate limit exceeded")
# Sanitize input
sanitized_prompt = sanitize_input(request_data.prompt)
# Get API key
api_key = get_gemini_api_key()
# Prepare request to Gemini API
gemini_url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key={api_key}"
payload = {
"contents": [{
"parts": [{
"text": sanitized_prompt
}]
}],
"generationConfig": {
"temperature": request_data.temperature,
"maxOutputTokens": request_data.max_tokens,
"topP": 0.8,
"topK": 40
}
}
# Make request to Gemini API
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
gemini_url,
json=payload,
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
logger.error(f"Gemini API error: {response.status_code} - {response.text}")
raise HTTPException(status_code=502, detail="External API error")
# Parse response
gemini_response = response.json()
# Extract generated text
if 'candidates' in gemini_response and len(gemini_response['candidates']) > 0:
generated_text = gemini_response['candidates'][0]['content']['parts'][0]['text']
else:
generated_text = "No response generated"
# Log successful request (without sensitive data)
logger.info(f"Successful proxy request from IP: {client_ip}")
return ProxyResponse(
response=generated_text,
status="success",
timestamp=datetime.now().isoformat()
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error in proxy endpoint: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy", "timestamp": datetime.now().isoformat()}
# Mount static files from the build directory.
# This serves the React build files at the root path; without a build the
# app falls back to a minimal Gradio demo interface.
if os.path.exists("./build"):
    app.mount("/static", StaticFiles(directory="./build/static"), name="static")

    @app.get("/")
    async def serve_react_app():
        """Serve the main React app"""
        return FileResponse("./build/index.html")

    @app.get("/{full_path:path}")
    async def serve_react_routes(full_path: str):
        """Serve React routes (for client-side routing)"""
        # Skip API routes
        if full_path.startswith(("proxy", "health", "gradio")):
            raise HTTPException(status_code=404, detail="Not found")
        # Resolve the requested file and require it to stay inside ./build —
        # the previous f"./build/{full_path}" lookup allowed '../' path
        # traversal to arbitrary files on disk.
        build_root = os.path.realpath("./build")
        candidate = os.path.realpath(os.path.join(build_root, full_path))
        if candidate.startswith(build_root + os.sep) and os.path.isfile(candidate):
            return FileResponse(candidate)
        # For all other routes, serve the React app (client-side routing)
        return FileResponse("./build/index.html")
else:
    # Fallback to Gradio if build directory doesn't exist
    logger.warning("Build directory not found, falling back to Gradio interface")

    def gradio_interface(prompt: str, temperature: float = 0.7, max_tokens: int = 1000):
        """Gradio wrapper for the proxy function"""
        try:
            # This would typically call the proxy endpoint internally
            # For demo purposes, we'll return a placeholder
            return f"[Demo Mode] Processed prompt: {prompt[:100]}..."
        except Exception as e:
            return f"Error: {str(e)}"

    # Create Gradio interface
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3),
            gr.Slider(0, 1, value=0.7, label="Temperature"),
            gr.Slider(1, 4000, value=1000, step=1, label="Max Tokens")
        ],
        outputs=gr.Textbox(label="Response", lines=5),
        title="Secure Gemini API Proxy",
        description="A secure proxy interface for Google's Gemini API with rate limiting and input validation."
    )

    # Mount Gradio app at /gradio path to keep it accessible
    app = gr.mount_gradio_app(app, iface, path="/gradio")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860) |