Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import logging | |
| from typing import Dict, Any, Optional | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel, validator | |
| import httpx | |
| import gradio as gr | |
| from datetime import datetime, timedelta | |
| from collections import defaultdict | |
| import hashlib | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Secure Gemini API Proxy", version="1.0.0")

# CORS policy. Allowed origins come from the comma-separated ALLOWED_ORIGINS
# environment variable and default to "*" (any origin), as before.
# Credentials are only enabled for an explicit allow-list: under a "*" policy
# with allow_credentials=True, Starlette answers credentialed requests by
# echoing the caller's Origin, which would let any website send
# cookie-bearing requests to this service.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["POST", "GET"],  # GET is needed for the static/React routes
    allow_headers=["*"],
)

# In-memory rate-limit state: client key -> list of request datetimes.
# It lives in this process only, so it is not shared between workers or
# restarts; use a shared store (e.g. Redis) for that.
rate_limit_storage = defaultdict(list)
# Request models
class GeminiRequest(BaseModel):
    """JSON body accepted by the Gemini proxy endpoint.

    Each field is checked by a ``@validator`` hook below; a failing check
    raises ``ValueError``, which pydantic reports as a validation error.
    """

    prompt: str
    # Forwarded to Gemini as generationConfig.temperature.
    temperature: Optional[float] = 0.7
    # Forwarded to Gemini as generationConfig.maxOutputTokens.
    max_tokens: Optional[int] = 1000

    # The hooks must be registered with @validator: an undecorated method is
    # just a class attribute that pydantic never calls, so no check would run.
    @validator("prompt")
    def validate_prompt(cls, v):
        """Reject blank prompts and prompts over 10,000 characters; return it stripped."""
        if not v or not v.strip():
            raise ValueError('Prompt cannot be empty')
        if len(v) > 10000:  # Reasonable limit
            raise ValueError('Prompt too long')
        return v.strip()

    @validator("temperature")
    def validate_temperature(cls, v):
        """Require the temperature, when given, to lie in [0, 1]."""
        if v is not None and (v < 0 or v > 1):
            raise ValueError('Temperature must be between 0 and 1')
        return v

    @validator("max_tokens")
    def validate_max_tokens(cls, v):
        """Require max_tokens, when given, to lie in [1, 4000]."""
        if v is not None and (v < 1 or v > 4000):
            raise ValueError('Max tokens must be between 1 and 4000')
        return v
class ProxyResponse(BaseModel):
    """Envelope returned to clients by the proxy endpoint."""

    response: str  # generated text, or "No response generated" when none came back
    status: str  # outcome marker; the proxy endpoint sets "success"
    timestamp: str  # datetime.now().isoformat() captured when the response is built
# Security functions
def get_client_ip(request: "Request", *, trusted_proxy_hops: Optional[int] = None) -> str:
    """Return the client address used as the rate-limit key.

    Each proxy appends the address it received the connection from to
    ``X-Forwarded-For``. Only entries appended by our own proxies are
    trustworthy; everything to their left is supplied by the client and can
    be forged. A leftmost-entry lookup would let any caller evade rate
    limiting by sending a fresh fake header on every request.

    Args:
        request: incoming request; only ``headers`` and ``client`` are read.
        trusted_proxy_hops: number of reverse proxies in front of this app
            that append to ``X-Forwarded-For``. The client address is taken
            this many positions from the right. When None, the
            ``TRUSTED_PROXY_HOPS`` environment variable is used (default 1).
            NOTE(review): confirm the hop count for the actual deployment.

    Returns:
        The selected forwarded address; otherwise the socket peer host; or
        ``"unknown"`` when neither is available.
    """
    if trusted_proxy_hops is None:
        try:
            trusted_proxy_hops = int(os.getenv("TRUSTED_PROXY_HOPS", "1"))
        except ValueError:
            trusted_proxy_hops = 1
    hops = max(trusted_proxy_hops, 1)

    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # Ignore empty entries such as those produced by "a,,b" or " , ".
        entries = [entry.strip() for entry in forwarded_for.split(",") if entry.strip()]
        if entries:
            # Clamp so a header shorter than the hop count yields its first entry.
            return entries[max(len(entries) - hops, 0)]
    return request.client.host if request.client else "unknown"
def check_rate_limit(
    client_ip: str,
    max_requests: int = 30,
    window_minutes: int = 1,
    *,
    storage: Optional[Dict[str, list]] = None,
    max_tracked_clients: int = 10_000,
) -> bool:
    """Sliding-window rate-limit check that also records allowed requests.

    A request is allowed when fewer than ``max_requests`` previously allowed
    requests from ``client_ip`` fall inside the last ``window_minutes``.
    Allowed requests are recorded; rejected ones are not.

    Args:
        client_ip: key identifying the caller.
        max_requests: allowed requests per window.
        window_minutes: window length in minutes.
        storage: mapping of client key -> list of request datetimes; defaults
            to the module-level ``rate_limit_storage``.
        max_tracked_clients: once the mapping holds more keys than this, keys
            with no timestamp inside the window are evicted. Without this, one
            entry per distinct client would be kept forever.

    Returns:
        True if the request is allowed, False if the limit is exceeded.
    """
    if storage is None:
        storage = rate_limit_storage

    now = datetime.now()
    window_start = now - timedelta(minutes=window_minutes)

    # Bounded-memory sweep: only runs when the table is large, and only drops
    # clients that have no activity inside the current window.
    if len(storage) > max_tracked_clients:
        expired = [
            key
            for key, stamps in storage.items()
            if all(stamp <= window_start for stamp in stamps)
        ]
        for key in expired:
            del storage[key]

    # Keep only this client's timestamps that are still inside the window.
    recent = [stamp for stamp in storage.get(client_ip, ()) if stamp > window_start]
    storage[client_ip] = recent

    if len(recent) >= max_requests:
        return False

    recent.append(now)
    return True
def sanitize_input(text: str) -> str:
    """Delete HTML-significant characters from *text*.

    Every occurrence of ``<``, ``>``, ``&``, ``"`` and ``'`` is removed (not
    escaped); all other characters pass through unchanged.

    NOTE(review): removing quotes mangles ordinary prose ("don't" -> "dont"),
    and the prompt is sent to Gemini as JSON, not HTML. Escaping output where
    it is rendered is probably the better defence; confirm the intent.
    """
    removal_table = str.maketrans("", "", "<>&\"'")
    return text.translate(removal_table)
# Get Gemini API key from environment
def get_gemini_api_key() -> str:
    """Return the GEMINI_API_KEY environment variable.

    Raises:
        HTTPException: status 500 with a generic detail when the variable is
            unset or empty (the cause is logged server-side only).
    """
    key = os.getenv('GEMINI_API_KEY')
    if key:
        return key
    logger.error("GEMINI_API_KEY environment variable not found")
    raise HTTPException(status_code=500, detail="API configuration error")
def _build_gemini_payload(prompt: str, request_data: GeminiRequest) -> Dict[str, Any]:
    """Build the generateContent request body for a single-turn text prompt."""
    return {
        "contents": [{
            "parts": [{
                "text": prompt
            }]
        }],
        "generationConfig": {
            "temperature": request_data.temperature,
            "maxOutputTokens": request_data.max_tokens,
            "topP": 0.8,
            "topK": 40
        }
    }


def _extract_generated_text(gemini_response: Any) -> str:
    """Return the first text part of the first candidate, or a placeholder.

    A candidate can come back without ``content``/``parts`` (for example when
    generation was stopped by a safety filter). Such candidates, like an empty
    candidate list, yield the placeholder instead of raising KeyError.
    """
    if not isinstance(gemini_response, dict):
        return "No response generated"
    candidates = gemini_response.get("candidates") or []
    if candidates and isinstance(candidates[0], dict):
        parts = (candidates[0].get("content") or {}).get("parts") or []
        if parts and isinstance(parts[0], dict) and "text" in parts[0]:
            return parts[0]["text"]
    return "No response generated"


async def proxy_gemini(request_data: GeminiRequest, request: Request):
    """Secure proxy endpoint for Gemini API calls.

    Rate-limits by client IP, sanitizes the prompt, forwards it to Gemini's
    ``generateContent`` method and wraps the first candidate's text in a
    ``ProxyResponse``.

    NOTE(review): no route decorator is visible for this handler in this view;
    confirm it is registered on ``app``.

    Raises:
        HTTPException: 429 when rate limited; 502 when Gemini returns a
            non-200 status; 500 when the API key is missing or any unexpected
            error occurs.
    """
    try:
        # Get client IP and check rate limiting
        client_ip = get_client_ip(request)
        if not check_rate_limit(client_ip):
            logger.warning(f"Rate limit exceeded for IP: {client_ip}")
            raise HTTPException(status_code=429, detail="Rate limit exceeded")

        sanitized_prompt = sanitize_input(request_data.prompt)
        api_key = get_gemini_api_key()

        # The model can be changed without a code edit; the default is the original model.
        model = os.getenv("GEMINI_MODEL", "gemini-pro")
        gemini_url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{model}:generateContent"
        )
        payload = _build_gemini_payload(sanitized_prompt, request_data)

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                gemini_url,
                json=payload,
                # The key goes in a header, not a "?key=" query parameter:
                # httpx logs request URLs at INFO level and exception messages
                # embed the URL, so a query-string key would leak into the logs.
                headers={
                    "Content-Type": "application/json",
                    "x-goog-api-key": api_key,
                },
            )

        if response.status_code != 200:
            logger.error(f"Gemini API error: {response.status_code} - {response.text}")
            raise HTTPException(status_code=502, detail="External API error")

        generated_text = _extract_generated_text(response.json())

        # Log successful request (without sensitive data)
        logger.info(f"Successful proxy request from IP: {client_ip}")
        return ProxyResponse(
            response=generated_text,
            status="success",
            timestamp=datetime.now().isoformat()
        )
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback server-side; the client only
        # sees a generic message.
        logger.exception(f"Unexpected error in proxy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
async def health_check():
    """Report liveness plus the server's current local time as ISO-8601."""
    payload = {"status": "healthy"}
    payload["timestamp"] = datetime.now().isoformat()
    return payload
# Serve the React production build from ./build when it exists; otherwise
# only log a warning and rely on the Gradio UI mounted at /gradio.
if os.path.exists("./build"):
    # StaticFiles refuses to start on a missing directory, so mount the asset
    # folder only when the build actually produced one.
    if os.path.isdir("./build/static"):
        app.mount("/static", StaticFiles(directory="./build/static"), name="static")

    async def serve_react_app():
        """Serve the React single-page app's entry document (./build/index.html)."""
        # NOTE(review): no route decorator is visible here; confirm registration.
        return FileResponse("./build/index.html")

    async def serve_react_routes(full_path: str):
        """Serve a file from ./build, or index.html for client-side routes.

        Paths starting with "proxy", "health" or "gradio" get a 404. A
        requested path is only served as a file if it resolves to a regular
        file *inside* ./build; anything else (including traversal attempts
        like "../app.py" or absolute paths) falls back to index.html.
        """
        # NOTE(review): no route decorator is visible here; confirm registration.
        if full_path.startswith(("proxy", "health", "gradio")):
            raise HTTPException(status_code=404, detail="Not found")

        build_root = os.path.realpath("./build")
        candidate = os.path.realpath(os.path.join(build_root, full_path))
        try:
            inside_build = os.path.commonpath([build_root, candidate]) == build_root
        except ValueError:  # e.g. different drives on Windows
            inside_build = False

        if inside_build and os.path.isfile(candidate):
            return FileResponse(candidate)

        # For all other routes, serve the React app (client-side routing)
        return FileResponse("./build/index.html")
else:
    # Fallback to Gradio if build directory doesn't exist
    logger.warning("Build directory not found, falling back to Gradio interface")
# Gradio interface
def gradio_interface(prompt: str, temperature: float = 0.7, max_tokens: int = 1000):
    """Placeholder Gradio handler that echoes a preview of the prompt.

    Returns "[Demo Mode] Processed prompt: " followed by the first 100
    characters of *prompt* and "...". Any exception raised while building
    that string is returned as "Error: <message>" instead of propagating.
    ``temperature`` and ``max_tokens`` match the UI inputs but are unused.

    NOTE(review): this never calls Gemini, so the Gradio UI is demo-only.
    """
    try:
        preview = prompt[:100]
        message = f"[Demo Mode] Processed prompt: {preview}..."
    except Exception as err:
        return f"Error: {str(err)}"
    return message
# Gradio UI: three inputs whose ranges mirror GeminiRequest's validators,
# and one text output.
_gradio_inputs = [
    gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3),
    gr.Slider(0, 1, value=0.7, label="Temperature"),
    gr.Slider(1, 4000, value=1000, step=1, label="Max Tokens"),
]
_gradio_output = gr.Textbox(label="Response", lines=5)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_gradio_inputs,
    outputs=_gradio_output,
    title="Secure Gemini API Proxy",
    description="A secure proxy interface for Google's Gemini API with rate limiting and input validation.",
)

# Attach the UI under /gradio; `app` is rebound to the object returned by
# gr.mount_gradio_app so later code (and uvicorn) serves the combined app.
app = gr.mount_gradio_app(app, iface, path="/gradio")
if __name__ == "__main__":
    # Direct-run entry point: serve the combined app on all interfaces, port 7860.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)