Spaces:
Paused
Paused
| from fastapi import FastAPI, Request | |
| import threading | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import requests, os, random | |
| import psycopg2 | |
| import logging | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from runware import Runware, IImageInference | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| from imdb import Cinemagoer | |
# --- Service credentials and endpoints -------------------------------------
# All values come from the process environment; os.getenv returns None for any
# variable that is unset.
# NOTE(review): load_dotenv is imported but never called in the visible code,
# so a local .env file is not loaded here — confirm that is intended.
IMGBB_API_KEY = os.getenv("IMGBB_API_KEY")
GIST_URL = os.getenv("GIST_URL")
TMDB_API_KEY = os.getenv("TMDB_API_KEY")
RUNWARE_API_KEY = os.getenv("RUNWARE_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")

# OpenAI-compatible client used for prompt enhancement.  The "/v1" suffix is
# appended only when a base URL is configured: concatenating onto None would
# raise TypeError while the module is being imported.  base_url=None lets the
# client use its own default endpoint.
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=f"{OPENAI_API_BASE}/v1" if OPENAI_API_BASE else None,
)

# Directory used by the file-based helper endpoints.
DATA_DIR = "/data"

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    # Every origin is accepted.
    # NOTE(review): wildcard origins together with allow_credentials=True is
    # very permissive — confirm this is intended outside development.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Keyword arguments passed to psycopg2.connect(); TLS is always required.
db_params = {
    'dbname': os.getenv('DB_NAME'),
    'user': os.getenv('DB_USER'),
    'password': os.getenv('DB_PASSWORD'),
    'host': os.getenv('DB_HOST'),
    'port': os.getenv('DB_PORT'),
    'sslmode': 'require',
}
class ImageRequest(BaseModel):
    """Request body describing one image-generation job."""

    # Text prompt forwarded to the image model as the positive prompt.
    prompt: str
    # Requested output size.
    width: int
    height: int
    # Model identifier passed through to the inference request.
    model: str
    # How many images to generate; one when omitted.
    number_results: int = 1
def insert_batch(prompt: str, width: int, height: int, model: str, urls: list[str]) -> int:
    """Persist a batch row plus one image row per URL and return the batch id.

    Both inserts run in a single transaction: either the batch and all of its
    images are committed, or nothing is.

    Args:
        prompt, width, height, model: Column values for the ``batches`` row.
        urls: Image URLs; one ``images`` row is written for each.

    Returns:
        The ``id`` generated for the new ``batches`` row.

    Raises:
        Exception: Any connection or database error is printed and re-raised.
    """
    conn = None
    try:
        conn = psycopg2.connect(**db_params)
        # `with conn` commits on success and rolls back on error; the cursor
        # context closes the cursor on both paths.
        with conn, conn.cursor() as cur:
            cur.execute(
                "INSERT INTO batches (prompt, width, height, model) VALUES (%s, %s, %s, %s) RETURNING id",
                (prompt, width, height, model),
            )
            batch_id = cur.fetchone()[0]
            cur.executemany(
                "INSERT INTO images (batch_id, url) VALUES (%s, %s)",
                [(batch_id, url) for url in urls],
            )
        return batch_id
    except Exception as error:
        print(f"Error inserting batch: {error}")
        raise
    finally:
        # `with conn` does not close the connection, so do it explicitly.
        if conn is not None:
            conn.close()
def upload_to_imgbb(image_url):
    """Re-host an image on imgbb and return the hosted URL.

    Posts ``image_url`` and ``IMGBB_API_KEY`` as form fields to the imgbb
    upload endpoint.

    Returns:
        The ``data.url`` field of the JSON reply when the endpoint answers
        HTTP 200, otherwise ``None``.

    Raises:
        requests.RequestException: On network failure, including a timeout.
    """
    response = requests.post(
        "https://api.imgbb.com/1/upload",
        data={"key": IMGBB_API_KEY, "image": image_url},
        # Without a timeout, requests waits indefinitely on a stalled server.
        timeout=30,
    )
    if response.status_code != 200:
        return None
    return response.json()['data']['url']
def upload_image(url):
    """Upload ``url`` via ``upload_to_imgbb`` and print the outcome.

    Returns the hosted URL, or ``None`` when the upload yields no URL.
    """
    hosted = upload_to_imgbb(url)
    if not hosted:
        print(f"Failed to upload: {url}")
        return None
    print(f"Uploaded: {url} -> {hosted}")
    return hosted
async def generate_image(request: ImageRequest):
    """Generate images through Runware and return their URLs as a batch.

    A new Runware client is created and connected for each call.  The
    generated URLs are returned as-is; re-hosting them on imgbb and persisting
    the batch are not performed here.

    NOTE(review): no route decorator is visible in this view; a commented-out
    copy elsewhere in this file suggests POST /generate-image — confirm.

    Raises:
        HTTPException: 500 carrying the underlying error text on any failure.
    """
    print("Image Generation Request")
    try:
        sdk = Runware(api_key=RUNWARE_API_KEY)
        await sdk.connect()

        inference = IImageInference(
            positivePrompt=request.prompt,
            model=request.model,
            numberResults=request.number_results,
            height=request.height,
            width=request.width,
        )
        results = await sdk.imageInference(requestImage=inference)

        image_urls = [item.imageURL for item in results]
        print("Generated Images: ", image_urls)

        return {
            "batch": {
                "prompt": request.prompt,
                "width": request.width,
                "height": request.height,
                "model": request.model,
                "images": [{"url": link} for link in image_urls],
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to generate image: {str(e)}")
| # @app.post("/generate-image") | |
| # async def generate_image(request: ImageRequest): | |
| # print("Image Generation Request", request) | |
| # try: | |
| # runware = Runware(api_key=RUNWARE_API_KEY) | |
| # await runware.connect() | |
| # request_image = IImageInference( | |
| # positivePrompt=request.prompt, | |
| # model=request.model, | |
| # numberResults=request.number_results, | |
| # height=request.height, | |
| # width=request.width, | |
| # ) | |
| # images = await runware.imageInference(requestImage=request_image) | |
| # image_urls = [image.imageURL for image in images] | |
| # print("Generated Images: ", image_urls) | |
| # response = { | |
| # "batch": { | |
| # "prompt": request.prompt, | |
| # "width": request.width, | |
| # "height": request.height, | |
| # "model": request.model, | |
| # "images": [{"url": url} for url in image_urls] | |
| # } | |
| # } | |
| # return response | |
| # except Exception as e: | |
| # raise HTTPException(status_code=500, detail=f"Failed to generate image: {str(e)}") | |
async def get_batches():
    """Return the five most recently created batches with their image URLs.

    Each batch is reported as a dict with ``id``, ``prompt``, ``width``,
    ``height``, ``model``, ``images`` (a list of ``{"url": ...}``) and
    ``createdAt`` (ISO-8601 text, or ``None`` when the column is empty).

    Raises:
        HTTPException: 500 carrying the error text on any failure.
    """
    conn = None
    try:
        conn = psycopg2.connect(**db_params)
        cur = conn.cursor()
        cur.execute("""
            SELECT b.id, b.prompt, b.width, b.height, b.model, array_agg(i.url) as image_urls, b.created_at
            FROM batches b
            JOIN images i ON i.batch_id = b.id
            GROUP BY b.id, b.prompt, b.width, b.height, b.model, b.created_at
            ORDER BY b.created_at DESC
            LIMIT 5
        """)
        # Column order below mirrors the SELECT list.
        batches = [
            {
                "id": batch_id,
                "prompt": prompt,
                "width": width,
                "height": height,
                "model": model,
                "images": [{"url": link} for link in image_urls],
                "createdAt": created_at.isoformat() if created_at else None,
            }
            for batch_id, prompt, width, height, model, image_urls, created_at in cur.fetchall()
        ]
        return {"batches": batches}
    except Exception as error:
        raise HTTPException(status_code=500, detail=str(error))
    finally:
        if conn is not None:
            conn.close()
def delete_batch(batch_id: int):
    """Delete a batch and its images in one transaction.

    Returns ``True`` once the deletes are committed, ``False`` (after printing
    the error) when anything fails.
    """
    # Image rows go first, then the batch row they reference.
    statements = (
        "DELETE FROM images WHERE batch_id = %s",
        "DELETE FROM batches WHERE id = %s",
    )
    conn = None
    try:
        conn = psycopg2.connect(**db_params)
        cur = conn.cursor()
        for statement in statements:
            cur.execute(statement, (batch_id,))
        conn.commit()
    except Exception as error:
        print(f"Error deleting batch: {error}")
        return False
    else:
        return True
    finally:
        if conn is not None:
            conn.close()
async def delete_batch_route(id: int):
    """Delete the batch with the given id and report the result.

    Raises:
        HTTPException: 500 when ``delete_batch`` reports failure.
    """
    if not delete_batch(id):
        raise HTTPException(status_code=500, detail="Failed to delete batch")
    return {"message": "Batch deleted successfully"}
async def enhance_prompt(request: dict):
    """Ask the chat model to rewrite an image prompt in more detail.

    Args:
        request: JSON body; must contain a non-empty ``"prompt"`` entry.

    Returns:
        ``{"enhancedPrompt": <text returned by the model>}``.

    Raises:
        HTTPException: 400 when the prompt is missing or empty; 500 when the
            model call fails.
    """
    prompt = request.get("prompt")
    # Validate outside the try block so the 400 is not caught by the broad
    # handler below and reported as a 500.
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    try:
        response = client.chat.completions.create(
            model="gemini-1.5-flash-latest",
            messages=[
                {"role": "system", "content": "You are an AI assistant that enhances image generation prompts. Your task is to take a user's prompt and make it more detailed and descriptive, suitable for high-quality image generation."},
                {"role": "user", "content": f"Enhance this image generation prompt: {prompt}. Reply with the enhanced prompt only."},
            ],
        )
        enhanced_prompt = response.choices[0].message.content
    except Exception as e:
        # No module-level `logger` is defined in this file, so obtain one here
        # from the already-imported logging module.
        logging.getLogger(__name__).error("Error enhancing prompt: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to enhance prompt: {str(e)}") from e

    return {"enhancedPrompt": enhanced_prompt}
def greet_json():
    """Return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
async def generate_random_number():
    """Write a random integer from 1 to 100 to ``random_number.txt`` in DATA_DIR.

    The directory is created when missing and the file is overwritten.
    Returns a message naming the number and the file path.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    number = random.randint(1, 100)
    file_path = os.path.join(DATA_DIR, "random_number.txt")
    with open(file_path, "w") as handle:
        handle.write(str(number))
    return {"message": f"Random number {number} saved to {file_path}"}
async def list_directory():
    """List the entry names found in DATA_DIR.

    When the directory does not exist, an empty list is returned together
    with an explanatory message.
    """
    if os.path.exists(DATA_DIR):
        return {"items": os.listdir(DATA_DIR)}
    return {"items": [], "message": "Data directory does not exist yet"}
async def bing_image(request: Request):
    """Generate images via Bing image creation and re-host them on imgbb.

    The JSON body must contain ``"prompt"``.  The Bing ``_U`` auth cookie is
    downloaded from ``GIST_URL``, a generation job is submitted and polled
    until results appear, and each resulting image URL is uploaded to imgbb.

    Returns:
        ``{"images": [<imgbb url>, ...]}``; images whose upload failed are
        omitted.

    Raises:
        HTTPException: 400 when the prompt is missing; 502 when the auth
            cookie cannot be fetched or Bing returns no request id; 504 when
            results do not appear within ``TIMEOUT`` seconds.
    """
    import asyncio
    import json
    import re
    from urllib.parse import quote

    import httpx

    data = await request.json()
    prompt = data.get("prompt")
    if not prompt:
        # quote(None) further down would otherwise fail with a TypeError.
        raise HTTPException(status_code=400, detail="Prompt is required")

    TIMEOUT = 200  # seconds: per-request HTTP timeout and overall polling budget
    POLL_INTERVAL = 5  # seconds between polls of the results page
    TOKEN_FILE = "token.json"
    BASE_URL = "https://www.bing.com"
    USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0"

    class BingError(Exception):
        """Bing did not provide what is needed to continue."""

    class BingTimeout(BingError):
        """Results did not appear within the polling budget."""

    class BingDalle:
        """Async client for the Bing image-creation pages, authenticated by the ``_U`` cookie."""

        def __init__(self, auth_cookie):
            self.auth_cookie = self._load_config(auth_cookie)
            self.headers = {
                "User-Agent": USER_AGENT,
                "Cookie": f"_U={self.auth_cookie}",
            }
            self.client = httpx.AsyncClient(
                base_url=BASE_URL, headers=self.headers, timeout=TIMEOUT
            )

        async def __aenter__(self):
            return self

        async def __aexit__(self, *_):
            await self.client.aclose()

        def _load_config(self, auth_cookie):
            """Return the given cookie, or the ``_U`` entry of TOKEN_FILE when none is given."""
            if auth_cookie:
                return auth_cookie
            try:
                with open(TOKEN_FILE, "r") as file:
                    config = json.load(file)
                    return config.get("_U")
            except FileNotFoundError:
                raise ValueError("Auth cookie not provided and token file not found.")

        async def _get_coins(self):
            """Return the coin count scraped from the create page as a string, or None."""
            response = await self.client.get("/images/create")
            if response.is_success:
                coins_match = re.search(r'coins available">(\d+)<', response.text)
                return coins_match[1] if coins_match else None
            return None

        async def _poll_results(self, prompt, request_id):
            """Poll the results page until it contains ``gir_async`` or the budget runs out."""
            result_url = f"/images/create/async/results/{request_id}?q={quote(prompt)}"
            loop = asyncio.get_running_loop()
            deadline = loop.time() + TIMEOUT
            while True:
                response = await self.client.get(result_url)
                if response.is_success and "gir_async" in response.text:
                    return response
                # Bound the loop: a job that never completes must not poll forever.
                if loop.time() >= deadline:
                    raise BingTimeout("Timed out waiting for Bing image results")
                await asyncio.sleep(POLL_INTERVAL)

        def _construct_url(self, coins, prompt):
            """Build the submit URL; ``rt=4`` when coins remain, ``rt=3`` otherwise."""
            encoded_prompt = quote(prompt)
            # coins is None when the count could not be scraped; treat as zero.
            rt_value = "4" if coins and int(coins) > 0 else "3"
            return f"/images/create?q={encoded_prompt}&rt={rt_value}&FORM=GENCRE"

        async def _get_request_id(self, prompt, post_url):
            """Submit the prompt and return the ``id`` query value of the final URL, or None."""
            data = {"q": prompt, "qs": "ds"}
            response = await self.client.post(post_url, data=data, follow_redirects=True)
            if response.is_success:
                request_id = re.search(r"id=([^&]+)", str(response.url))
                return request_id[1] if request_id else None
            return None

        def _handle_poll_result(self, poll_results, prompt):
            """Extract the distinct generated-image URLs from the results HTML."""
            src_urls = list(
                {
                    url.split("?w=")[0]
                    for url in re.findall(r'src="([^"]+)"', poll_results.text)
                    if url.startswith("http") and url.endswith("ImgGn")
                }
            )
            return [{"url": src_url} for src_url in src_urls]

        async def generate_images(self, prompt):
            """Run the submit-and-poll flow and return ``[{"url": ...}, ...]``."""
            coins = await self._get_coins()
            post_url = self._construct_url(coins, prompt)
            request_id = await self._get_request_id(prompt, post_url)
            if request_id is None:
                # Polling ".../results/None" could never succeed.
                raise BingError("Bing did not return a request id for the prompt")
            poll_results = await self._poll_results(prompt, request_id)
            return self._handle_poll_result(poll_results, prompt)

    async def upload_to_imagebb(image_url, api_key):
        """Upload ``image_url`` to imgbb; return the hosted URL or None on failure."""
        upload_url = "https://api.imgbb.com/1/upload"
        params = {
            "key": api_key,
            "image": image_url,
        }
        async with httpx.AsyncClient() as http:
            response = await http.post(upload_url, params=params)
            if response.is_success:
                return response.json()["data"]["url"]
            return None

    async def main(prompt, auth_cookie, imagebb_api_key):
        """Generate images for ``prompt`` and return the imgbb URLs of those that uploaded."""
        async with BingDalle(auth_cookie) as bing:
            image_urls = await bing.generate_images(prompt)
            uploaded_urls = []
            for image in image_urls:
                uploaded_url = await upload_to_imagebb(image["url"], imagebb_api_key)
                if uploaded_url:
                    uploaded_urls.append(uploaded_url)
            return uploaded_urls

    # requests is synchronous: run it in a worker thread so the event loop is
    # not blocked, and give it a timeout so a stalled host cannot hang us.
    gist_response = await asyncio.to_thread(requests.get, GIST_URL, timeout=30)
    if gist_response.status_code != 200:
        raise HTTPException(status_code=502, detail="Failed to fetch Bing auth cookie")
    # Strip surrounding whitespace (e.g. a trailing newline) before the value
    # is placed into a Cookie header.
    auth_cookie = gist_response.text.strip()

    try:
        uploaded_urls = await main(prompt, auth_cookie, IMGBB_API_KEY)
    except BingTimeout as exc:
        raise HTTPException(status_code=504, detail=str(exc)) from exc
    except BingError as exc:
        raise HTTPException(status_code=502, detail=str(exc)) from exc
    return {"images": uploaded_urls}
async def parentalguide(request: Request):
    """Fetch the IMDb parents-guide data for a title.

    The JSON body may carry ``imdb_id`` directly and/or a ``tmdb_id``.  When
    ``tmdb_id`` is present, the TMDB TV ``external_ids`` endpoint is queried
    and its ``imdb_id`` (if any) takes precedence over the one in the body.

    Returns:
        ``{"pg_guide": <Cinemagoer movie object>}``.

    Raises:
        HTTPException: 400 when no IMDb id can be determined.

    NOTE(review): requests.get and Cinemagoer are synchronous calls inside an
    async function — presumably acceptable for this endpoint; confirm.
    """
    payload = await request.json()
    imdb_id = payload.get("imdb_id")
    tmdb_id = payload.get("tmdb_id", None)

    if tmdb_id is not None:
        response = requests.get(
            f"https://api.themoviedb.org/3/tv/{tmdb_id}/external_ids",
            params={"api_key": TMDB_API_KEY},
            timeout=30,
        )
        if response.status_code == 200:
            # TMDB may report a null imdb_id; keep the body's value in that
            # case instead of calling a method on None.
            imdb_id = response.json().get("imdb_id") or imdb_id

    if not imdb_id:
        raise HTTPException(status_code=400, detail="imdb_id could not be determined")

    # Strip the "tt" prefix regardless of where the id came from, so both
    # sources are handed to Cinemagoer in the same form.
    imdb_id = str(imdb_id).removeprefix("tt")
    print(imdb_id)

    ia = Cinemagoer()
    movie = ia.get_movie(imdb_id, info=['parents guide'])
    return {"pg_guide": movie}