ParthSadaria commited on
Commit
dc7e7dc
·
verified ·
1 Parent(s): a3bd0ac

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +54 -48
main.py CHANGED
@@ -1,26 +1,36 @@
1
  import os
2
  import re
3
- from dotenv import load_dotenv
4
- from fastapi import FastAPI, HTTPException, Request, Depends, Security, Query
5
- from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse, PlainTextResponse
6
- from fastapi.security import APIKeyHeader
7
- from pydantic import BaseModel
8
- import httpx
9
- from functools import lru_cache
10
- from pathlib import Path
11
  import json
12
- import datetime
13
  import time
 
14
  import asyncio
15
- from starlette.status import HTTP_403_FORBIDDEN
 
 
 
 
 
 
16
  import cloudscraper
17
  from concurrent.futures import ThreadPoolExecutor
18
- import uvloop
 
 
 
 
 
 
 
 
 
 
19
  from fastapi.middleware.gzip import GZipMiddleware
20
  from starlette.middleware.cors import CORSMiddleware
 
 
 
21
  import contextlib
22
- import requests
23
- from typing import List, Dict, Any, Optional, Union # Import Optional and other typing helpers
24
 
25
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
26
 
@@ -43,6 +53,9 @@ app.add_middleware(
43
  allow_methods=["*"],
44
  allow_headers=["*"],
45
  )
 
 
 
46
 
47
  @lru_cache(maxsize=1)
48
  def get_env_vars():
@@ -660,52 +673,45 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
660
  except Exception as e:
661
  print(f"An unexpected error occurred during non-streaming chat completion proxy: {e}")
662
  raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
663
-
664
- @app.post("/images/generations")
665
- async def create_image(payload: ImageGenerationPayload, request: Request, authenticated: bool = Depends(verify_api_key)):
 
 
 
 
 
666
  if not server_status:
667
- raise HTTPException(
668
- status_code=503,
669
- content={"message": "Server is under maintenance. Please try again later."}
670
- )
671
 
672
- if payload.model not in supported_image_models:
673
- raise HTTPException(
674
- status_code=400,
675
- detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {', '.join(supported_image_models)}"
676
- )
677
 
678
- usage_tracker.record_request(request=request, model=payload.model, endpoint="/images/generations")
 
 
 
679
 
680
- api_payload = {
681
- "model": payload.model,
682
- "prompt": payload.prompt,
683
- "size": payload.size,
684
- "n": payload.number
685
- }
686
-
687
- target_api_url = get_env_vars().get('new_img')
688
- if not target_api_url:
689
- raise HTTPException(status_code=500, detail="Image generation API endpoint (NEW_IMG) not configured.")
690
 
691
  try:
692
- client = get_async_client()
693
- response = await client.post(target_api_url, json=api_payload)
694
 
695
- response.raise_for_status()
 
696
 
697
- return JSONResponse(content=response.json())
 
 
 
698
 
699
  except httpx.TimeoutException:
700
- raise HTTPException(status_code=504, detail="Image generation request timed out.")
701
- except httpx.RequestError as e:
702
- raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
703
- except httpx.HTTPStatusError as e:
704
- error_detail = e.response.json().get("detail", f"Image generation failed with status code: {e.response.status_code}")
705
- raise HTTPException(status_code=e.response.status_code, detail=error_detail)
706
  except Exception as e:
707
- print(f"An unexpected error occurred during image generation: {e}")
708
- raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
709
 
710
  @app.get("/usage")
711
  async def get_usage_json(days: int = 7):
 
1
  import os
2
  import re
 
 
 
 
 
 
 
 
3
  import json
 
4
  import time
5
+ import datetime
6
  import asyncio
7
+ import hashlib
8
+ from pathlib import Path
9
+ from functools import lru_cache
10
+ from typing import List, Dict, Any, Optional, Union
11
+
12
+ import httpx
13
+ import requests
14
  import cloudscraper
15
  from concurrent.futures import ThreadPoolExecutor
16
+
17
+ from dotenv import load_dotenv
18
+ from fastapi import FastAPI, HTTPException, Request, Depends, Security, Query, APIRouter
19
+ from fastapi.responses import (
20
+ StreamingResponse,
21
+ HTMLResponse,
22
+ JSONResponse,
23
+ FileResponse,
24
+ PlainTextResponse,
25
+ )
26
+ from fastapi.security import APIKeyHeader
27
  from fastapi.middleware.gzip import GZipMiddleware
28
  from starlette.middleware.cors import CORSMiddleware
29
+ from starlette.status import HTTP_403_FORBIDDEN
30
+
31
+ import uvloop
32
  import contextlib
33
+
 
34
 
35
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
36
 
 
53
  allow_methods=["*"],
54
  allow_headers=["*"],
55
  )
56
+ @lru_cache(maxsize=256)
57
+ def cached_url(url: str):
58
+ return url
59
 
60
  @lru_cache(maxsize=1)
61
  def get_env_vars():
 
673
  except Exception as e:
674
  print(f"An unexpected error occurred during non-streaming chat completion proxy: {e}")
675
  raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
676
+ router = APIRouter()
677
+
678
+ @router.get("/images/{prompt:path}")
679
+ async def create_image(
680
+ prompt: str,
681
+ request: Request,
682
+ authenticated: bool = Depends(verify_api_key)
683
+ ):
684
  if not server_status:
685
+ raise HTTPException(status_code=503, detail="Server is under maintenance.")
 
 
 
686
 
687
+ # forward all GET params
688
+ query = request.url.query
 
 
 
689
 
690
+ base = "https://image.pollinations.ai/prompt/"
691
+ final_url = f"{base}{prompt}"
692
+ if query:
693
+ final_url += f"?{query}"
694
 
695
+ # caching
696
+ final_url = cached_url(final_url)
 
 
 
 
 
 
 
 
697
 
698
  try:
699
+ client = httpx.AsyncClient(timeout=60)
700
+ resp = await client.get(final_url)
701
 
702
+ if resp.status_code != 200:
703
+ raise HTTPException(status_code=resp.status_code, detail="Image generation failed.")
704
 
705
+ return StreamingResponse(
706
+ resp.aiter_bytes(),
707
+ media_type="image/jpeg"
708
+ )
709
 
710
  except httpx.TimeoutException:
711
+ raise HTTPException(status_code=504, detail="Image generation timeout.")
712
+
 
 
 
 
713
  except Exception as e:
714
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
 
715
 
716
  @app.get("/usage")
717
  async def get_usage_json(days: int = 7):