import base64
import io
import os
import random
import string
import time
import urllib.parse

import requests
import uvicorn
from deep_translator import GoogleTranslator
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from profanity_check import predict_prob

# Application object; endpoints are attached to it with @app.<method>() decorators.
app = FastAPI()

# Pollinations text-to-image endpoint. The prompt is appended to this base
# URL as an additional path segment.
API_URL: str = "https://image.pollinations.ai/prompt"


def is_prompt_safe(prompt: str, threshold: float = 0.8) -> bool:
    """Check whether a prompt is safe to send to the image API.

    The prompt is scored with ``profanity_check.predict_prob``. It is
    rejected when the predicted probability of being offensive is strictly
    greater than *threshold*.

    Args:
        prompt: User-supplied prompt text.
        threshold: Offensiveness probability above which the prompt is
            rejected. Defaults to 0.8, the value previously hard-coded here.

    Returns:
        True if the prompt is considered safe, False otherwise.
    """
    # predict_prob takes a batch of texts; score a batch of one.
    probability_offensive = predict_prob([prompt])[0]
    print(f"Prompt '{prompt}' has a predicted offensiveness score of: {probability_offensive:.4f}")

    if probability_offensive > threshold:
        return False
    return True


def add_random_noise(prompt, noise_level=0.00):
    """Replace a random subset of a prompt's characters with random noise.

    The share of characters replaced is ``noise_level * 5`` percent (for
    example, ``noise_level=2`` replaces 10% of them), rounded down and
    capped at the whole string. Replacements are drawn from ASCII letters,
    punctuation, the space character and digits, so a replacement may
    happen to equal the character it replaces.

    Args:
        prompt: Text to perturb. Falsy values (``""``, ``None``) are
            returned unchanged.
        noise_level: Noise strength. Any value that is not strictly
            positive (zero, negative, NaN) disables noise.

    Returns:
        The perturbed string, which has the same length as *prompt*, or
        *prompt* itself when no characters are replaced.
    """
    # `not noise_level > 0` also rejects NaN (which would crash int() below)
    # and negative values (which would give random.sample a negative k).
    if not prompt or not noise_level > 0:
        return prompt

    fraction = noise_level * 5 / 100
    # Cap at the whole string; this also stops an infinite noise_level from
    # overflowing int().
    num_noise_chars = int(len(prompt) * min(fraction, 1.0))
    if num_noise_chars == 0:
        return prompt

    noise_indices = random.sample(range(len(prompt)), num_noise_chars)
    prompt_list = list(prompt)
    noise_chars = string.ascii_letters + string.punctuation + ' ' + string.digits
    for index in noise_indices:
        prompt_list[index] = random.choice(noise_chars)
    return "".join(prompt_list)


def generate_image(inputs: str, is_negative: str, steps: int, cfg_scale: float, seed: int, noise_level: float, timeout: float = 120.0):
    """Translate a prompt to English and fetch a generated image from Pollinations.

    Args:
        inputs: Prompt text in any language; it is translated to English
            before being sent.
        is_negative: Accepted for interface compatibility; not currently
            forwarded to the backend.
        steps: Accepted for interface compatibility; not currently forwarded.
        cfg_scale: Accepted for interface compatibility; not currently
            forwarded.
        seed: Seed forwarded to the backend. When None, a random seed in
            [0, 100000] is used.
        noise_level: Passed to add_random_noise to perturb the translated
            prompt.
        timeout: Seconds to wait for the backend response before giving up.

    Returns:
        A decoded PIL Image on success, or None if any step fails. Errors are
        printed, not raised.
    """
    try:
        translator_to_en = GoogleTranslator(source='auto', target='english')
        english_inputs = translator_to_en.translate(inputs)
        prompt_with_noise = add_random_noise(english_inputs, noise_level)
        # The prompt becomes a URL path segment. Percent-encode everything,
        # including '/', '?', '#' and '%', which the noise step can inject,
        # so the prompt cannot change the URL's structure.
        encoded_prompt = urllib.parse.quote(prompt_with_noise, safe="")
        request_url = f"{API_URL}/{encoded_prompt}"
        params = {
            "seed": seed if seed is not None else random.randint(0, 100000),
            "width": 1024,
            "height": 1024,
            "nologo": "true",
            "model": "flux"
        }
        # Without a timeout a stalled backend would block this worker forever.
        response = requests.get(request_url, params=params, timeout=timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt or truncated data is
        # reported here instead of when the caller later uses the image.
        image.load()
        return image
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error fetching image from Pollinations: {e}")
        print(f"Error Response: {e.response.text}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred in generate_image: {e}")
        return None


@app.get("/send_inputs")
def send_inputs(
    inputs: str,
    noise_level: float,
    is_negative: str,
    steps: int = 20,
    cfg_scale: float = 4.5,
    seed: int = None
):
    """Generate an image for a prompt and return it as a base64-encoded JPEG.

    The query parameters are forwarded to generate_image.

    Returns:
        ``{"result_base64": <base64-encoded JPEG bytes as str>}``.

    Raises:
        HTTPException(400): The prompt is empty or flagged as inappropriate.
        HTTPException(500): The backend failed to produce an image, or the
            image could not be encoded.
    """
    if not inputs.strip():
        raise HTTPException(status_code=400, detail="Prompt must not be empty.")

    if not is_prompt_safe(inputs):
        raise HTTPException(
            status_code=400,
            detail="Inappropriate prompt detected. Please try again with a different prompt."
        )

    try:
        generated_image = generate_image(inputs, is_negative, steps, cfg_scale, seed, noise_level)
        if generated_image is None:
            raise HTTPException(status_code=500, detail="Failed to generate image from the backend service.")

        # JPEG cannot store alpha or palette modes (e.g. RGBA, P); convert
        # those to RGB so save() does not raise.
        if generated_image.mode not in ("RGB", "L", "CMYK"):
            generated_image = generated_image.convert("RGB")

        image_bytes = io.BytesIO()
        generated_image.save(image_bytes, format="JPEG")
        image_base64 = base64.b64encode(image_bytes.getvalue()).decode("utf-8")
        return {"result_base64": image_base64}
    except HTTPException:
        # Let deliberate HTTP errors through unchanged; otherwise the
        # generic handler below would replace their specific detail.
        raise
    except Exception as e:
        print(f"Error in send_inputs endpoint: {e}")
        raise HTTPException(status_code=500, detail="An unexpected server error occurred.") from e