File size: 8,540 Bytes
f112ff8
 
 
 
 
 
6897ce3
 
f112ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6897ce3
f112ff8
 
 
 
 
 
 
 
 
 
 
6897ce3
f112ff8
 
 
 
 
 
 
6897ce3
f112ff8
 
 
 
 
6897ce3
f112ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6897ce3
f112ff8
 
 
 
 
 
 
 
 
 
 
6897ce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f112ff8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import os
import json
import logging
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, validator
import httpx
import gradio as gr
from datetime import datetime, timedelta
from collections import defaultdict
import hashlib

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Secure Gemini API Proxy", version="1.0.0")

# CORS: allowed origins come from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (default "*", i.e. any origin). Credentials are only
# enabled when explicit origins are configured: the CORS spec forbids
# "Access-Control-Allow-Origin: *" together with credentials, and Starlette
# copes with that combination by reflecting the caller's Origin, which would
# grant credentialed cross-origin access to every site.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["POST", "GET"],  # GET is needed for the static frontend files
    allow_headers=["*"],
)

# Rate limiting storage (in production, use Redis)
# Maps client IP -> timestamps of that client's recent requests; written by
# check_rate_limit().
rate_limit_storage = defaultdict(list)

# Request models
class GeminiRequest(BaseModel):
    """Body of a ``POST /proxy`` request.

    Each validator raises ``ValueError`` for out-of-range input, which pydantic
    reports as a validation error for that field.
    """

    prompt: str  # returned stripped; must be non-blank and <= 10000 chars before stripping
    temperature: Optional[float] = 0.7  # must lie in [0, 1] when given
    max_tokens: Optional[int] = 1000  # must lie in [1, 4000] when given

    @validator('prompt')
    def validate_prompt(cls, v):
        """Reject empty/whitespace-only or over-long prompts; return the stripped text."""
        if not v or not v.strip():
            raise ValueError('Prompt cannot be empty')
        if len(v) > 10000:  # Reasonable limit (measured before stripping)
            raise ValueError('Prompt too long')
        return v.strip()

    @validator('temperature')
    def validate_temperature(cls, v):
        """Require 0 <= temperature <= 1 (NaN is rejected as well)."""
        # Written as ``not (lo <= v <= hi)`` rather than ``v < lo or v > hi``:
        # every comparison with NaN is False, so the latter form accepts NaN.
        if v is not None and not (0 <= v <= 1):
            raise ValueError('Temperature must be between 0 and 1')
        return v

    @validator('max_tokens')
    def validate_max_tokens(cls, v):
        """Require 1 <= max_tokens <= 4000."""
        if v is not None and not (1 <= v <= 4000):
            raise ValueError('Max tokens must be between 1 and 4000')
        return v

class ProxyResponse(BaseModel):
    """Response model of the ``POST /proxy`` endpoint."""

    response: str  # text produced by Gemini, or a placeholder message when none was extracted
    status: str  # set to "success" by the /proxy handler
    timestamp: str  # ISO 8601 string from datetime.isoformat() (naive server-local time)

# Security functions
def get_client_ip(request: Request, trusted_proxy_hops: int = 1) -> str:
    """Return a best-effort client IP address for *request*.

    ``X-Forwarded-For`` is a comma-separated list to which each proxy appends
    the address it received the request from. Only entries appended by proxies
    we control are trustworthy; everything to their left was supplied by the
    client and can be forged (for example to dodge the per-IP rate limit).
    So the entry ``trusted_proxy_hops`` positions from the right is used,
    not the leftmost one.

    Args:
        request: the incoming request.
        trusted_proxy_hops: number of reverse proxies in front of this app
            that append to ``X-Forwarded-For``. Use 0 when the app is exposed
            directly, so the header is ignored entirely.

    Returns:
        The selected forwarded address, otherwise the socket peer address,
        otherwise ``"unknown"``.
    """
    if trusted_proxy_hops > 0:
        forwarded_for = request.headers.get("X-Forwarded-For", "")
        hops = [hop.strip() for hop in forwarded_for.split(",") if hop.strip()]
        if hops:
            # Fewer entries than trusted proxies: the leftmost is the best we have.
            return hops[-min(trusted_proxy_hops, len(hops))]
    return request.client.host if request.client else "unknown"

def check_rate_limit(
    client_ip: str,
    max_requests: int = 30,
    window_minutes: int = 1,
    *,
    storage: Optional[Dict[str, list]] = None,
    max_tracked_clients: int = 10000,
) -> bool:
    """Sliding-window rate limit: record a request from *client_ip* if allowed.

    A client may make at most ``max_requests`` requests in any trailing window
    of ``window_minutes`` minutes. Rejected requests are not recorded.

    Args:
        client_ip: key identifying the client.
        max_requests: requests permitted per window.
        window_minutes: window length in minutes.
        storage: mapping of client key -> list of request timestamps; defaults
            to the module-level ``rate_limit_storage``.
        max_tracked_clients: once more clients than this are tracked, clients
            with no request inside the current window are evicted, so the table
            cannot grow without bound.

    Returns:
        True if the request is within the limit (and was recorded), else False.
    """
    # Monotonic clock: unlike datetime.now(), it never jumps backwards on
    # DST changes or clock corrections, which would keep old entries
    # "inside" the window and lock clients out.
    import time

    if storage is None:
        storage = rate_limit_storage

    now = time.monotonic()
    window_start = now - window_minutes * 60

    # Keep only this client's requests that are still inside the window.
    recent = [t for t in storage.get(client_ip, ()) if t > window_start]
    allowed = len(recent) < max_requests
    if allowed:
        recent.append(now)
    storage[client_ip] = recent

    if len(storage) > max_tracked_clients:
        # Timestamps are appended in increasing order, so the last one is the
        # newest; a client whose newest request has aged out is idle.
        idle = [ip for ip, times in storage.items() if not times or times[-1] <= window_start]
        for ip in idle:
            del storage[ip]

    return allowed

def sanitize_input(text: str) -> str:
    """Return *text* with every ``<``, ``>``, ``&``, ``"`` and ``'`` deleted.

    NOTE(review): this is lossy for ordinary prompts (e.g. "don't" becomes
    "dont"), and in this module the result is only placed in a JSON request
    body. Escaping at the point where output is rendered may be the better
    defence; confirm before relying on it.
    """
    # Deletion table built via the third argument of str.maketrans: each listed
    # character maps to None, so translate() drops it in a single pass.
    deletions = str.maketrans('', '', '<>&"\'')
    return text.translate(deletions)

# Get Gemini API key from environment
def get_gemini_api_key() -> str:
    """Read the Gemini API key from the ``GEMINI_API_KEY`` environment variable.

    Raises:
        HTTPException: status 500 with a generic detail when the variable is
            unset or empty; the specific cause is only logged server-side.
    """
    key = os.environ.get('GEMINI_API_KEY')
    if key:
        return key
    logger.error("GEMINI_API_KEY environment variable not found")
    raise HTTPException(status_code=500, detail="API configuration error")

def _extract_generated_text(gemini_response: Any) -> str:
    """Return ``candidates[0].content.parts[0].text`` or a placeholder if absent.

    Any missing or malformed piece along that path (e.g. a candidate without
    ``content``) yields "No response generated" instead of raising, so it is
    not reported to the client as an internal server error.
    """
    try:
        text = gemini_response['candidates'][0]['content']['parts'][0]['text']
    except (KeyError, IndexError, TypeError):
        return "No response generated"
    return text if isinstance(text, str) else "No response generated"


@app.post("/proxy", response_model=ProxyResponse)
async def proxy_gemini(request_data: GeminiRequest, request: Request):
    """Secure proxy endpoint for Gemini API calls.

    Applies the per-IP rate limit, sanitizes the prompt, forwards it to the
    Gemini ``generateContent`` endpoint using the server-side API key and
    returns the first generated text.

    Raises:
        HTTPException: 429 when the client is rate limited; 502 when Gemini
            answers with a non-200 status or a non-JSON body, or cannot be
            reached; 504 when the Gemini request times out; 500 for missing
            configuration or any unexpected error.
    """
    try:
        # Get client IP and check rate limiting
        client_ip = get_client_ip(request)
        if not check_rate_limit(client_ip):
            logger.warning(f"Rate limit exceeded for IP: {client_ip}")
            raise HTTPException(status_code=429, detail="Rate limit exceeded")

        sanitized_prompt = sanitize_input(request_data.prompt)
        api_key = get_gemini_api_key()

        # Model is configurable via GEMINI_MODEL; the default is unchanged.
        model = os.getenv('GEMINI_MODEL', 'gemini-pro')
        gemini_url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{model}:generateContent"
        )

        # Options left unset (None) are omitted rather than sent as JSON null.
        generation_config = {
            key: value
            for key, value in (
                ("temperature", request_data.temperature),
                ("maxOutputTokens", request_data.max_tokens),
                ("topP", 0.8),
                ("topK", 40),
            )
            if value is not None
        }
        payload = {
            "contents": [{"parts": [{"text": sanitized_prompt}]}],
            "generationConfig": generation_config,
        }

        # The key is sent in the x-goog-api-key header, not as a ?key= query
        # parameter: httpx logs request URLs at INFO level and this module
        # configures the root logger at INFO, so a query-string key would be
        # written to the logs on every call.
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                gemini_url,
                json=payload,
                headers={
                    "Content-Type": "application/json",
                    "x-goog-api-key": api_key,
                },
            )

        if response.status_code != 200:
            logger.error(f"Gemini API error: {response.status_code} - {response.text}")
            raise HTTPException(status_code=502, detail="External API error")

        try:
            gemini_response = response.json()
        except ValueError:
            logger.error("Gemini API returned a non-JSON response body")
            raise HTTPException(status_code=502, detail="External API error")

        generated_text = _extract_generated_text(gemini_response)

        # Log successful request (without sensitive data)
        logger.info(f"Successful proxy request from IP: {client_ip}")

        return ProxyResponse(
            response=generated_text,
            status="success",
            timestamp=datetime.now().isoformat()
        )

    except HTTPException:
        raise
    except httpx.TimeoutException:
        logger.error("Timed out waiting for the Gemini API")
        raise HTTPException(status_code=504, detail="External API timeout")
    except httpx.HTTPError as e:
        logger.error(f"Could not reach the Gemini API: {e!r}")
        raise HTTPException(status_code=502, detail="External API error")
    except Exception as e:
        # logger.exception records the traceback as well as the message.
        logger.exception(f"Unexpected error in proxy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")

@app.get("/health")
async def health_check():
    """Liveness probe: report the service as healthy along with the server's current time."""
    now_iso = datetime.now().isoformat()
    return dict(status="healthy", timestamp=now_iso)

# Mount static files from the build directory
# This serves the React build files at the root path.
# The build directory is resolved to an absolute real path once, at import
# time (relative to the working directory then), so that request paths can be
# checked for containment against it.
_BUILD_DIR = os.path.realpath("./build")

if os.path.isdir(_BUILD_DIR):
    _STATIC_DIR = os.path.join(_BUILD_DIR, "static")
    _INDEX_HTML = os.path.join(_BUILD_DIR, "index.html")

    if os.path.isdir(_STATIC_DIR):
        app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")
    else:
        # StaticFiles refuses a missing directory; keep serving the rest of the app.
        logger.warning("Build directory has no static/ subdirectory; /static is not mounted")

    def _resolve_build_file(relative_path: str) -> Optional[str]:
        """Return the absolute path of a regular file inside the build directory, else None.

        The path is fully resolved (``..`` segments, absolute components,
        symlinks) before the containment check, so a request path containing
        ``..`` cannot read files outside the build directory.
        """
        try:
            candidate = os.path.realpath(os.path.join(_BUILD_DIR, relative_path))
        except ValueError:  # e.g. an embedded NUL byte in the path
            return None
        if candidate.startswith(_BUILD_DIR + os.sep) and os.path.isfile(candidate):
            return candidate
        return None

    @app.get("/")
    async def serve_react_app():
        """Serve the main React app"""
        return FileResponse(_INDEX_HTML)

    @app.get("/{full_path:path}")
    async def serve_react_routes(full_path: str):
        """Serve React routes (for client-side routing)"""
        # Skip API routes
        if full_path.startswith(("proxy", "health", "gradio")):
            raise HTTPException(status_code=404, detail="Not found")

        # Real files from the build output are served as-is
        file_path = _resolve_build_file(full_path)
        if file_path is not None:
            return FileResponse(file_path)

        # For all other routes, serve the React app (client-side routing)
        return FileResponse(_INDEX_HTML)
else:
    # Fallback to Gradio if build directory doesn't exist
    logger.warning("Build directory not found, falling back to Gradio interface")

    # Gradio interface
    def gradio_interface(prompt: str, temperature: float = 0.7, max_tokens: int = 1000):
        """Gradio wrapper for the proxy function (demo placeholder; no API call is made)."""
        try:
            # This would typically call the proxy endpoint internally
            # For demo purposes, we'll return a placeholder
            return f"[Demo Mode] Processed prompt: {prompt[:100]}..."
        except Exception as e:
            return f"Error: {str(e)}"

    # Create Gradio interface
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3),
            gr.Slider(0, 1, value=0.7, label="Temperature"),
            gr.Slider(1, 4000, value=1000, step=1, label="Max Tokens")
        ],
        outputs=gr.Textbox(label="Response", lines=5),
        title="Secure Gemini API Proxy",
        description="A secure proxy interface for Google's Gemini API with rate limiting and input validation."
    )

    # Mount Gradio app at /gradio path to keep it accessible
    app = gr.mount_gradio_app(app, iface, path="/gradio")

if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is run as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)