# Arpr / main.py
# Last commit: e3974fa "Update main.py" (Asartb, verified)
import base64
import io
import os
import random
import string
import time
from urllib.parse import quote

import requests
import uvicorn
from deep_translator import GoogleTranslator
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from profanity_check import predict_prob  # Import the library
# FastAPI application object; the route handlers in this module are
# registered on it via decorators.
app = FastAPI()

# Base URL of the Pollinations AI image endpoint. generate_image appends the
# prompt to it as a path segment.
API_URL: str = "https://image.pollinations.ai/prompt"
# --- Safety check ---
def is_prompt_safe(prompt: str, *, threshold: float = 0.8) -> bool:
    """Return True if *prompt* is safe enough to send to the image API.

    The prompt is scored with ``profanity_check.predict_prob``, which yields
    the probability that the text is offensive. Prompts scoring strictly above
    *threshold* are rejected.

    Args:
        prompt: User-supplied text to screen.
        threshold: Maximum tolerated offensiveness probability (keyword-only).
            The default 0.8 is a reasonably strict starting point and matches
            the previously hard-coded value.

    Returns:
        True when the prompt is considered safe, False otherwise.
    """
    # predict_prob takes a list of texts; element 0 is this prompt's score.
    probability_offensive = predict_prob([prompt])[0]
    print(f"Prompt '{prompt}' has a predicted offensiveness score of: {probability_offensive:.4f}")
    # Written as "<=" so a non-comparable score (NaN) fails closed.
    return probability_offensive <= threshold
def add_random_noise(prompt, noise_level=0.00):
    """Replace a random subset of characters in *prompt* with random characters.

    The number of replaced characters is
    ``int(len(prompt) * noise_level * 5 / 100)``, i.e. each unit of
    *noise_level* corresponds to 5% of the prompt, capped at the prompt length.
    Positions are chosen without replacement. Each chosen position receives a
    character drawn from ASCII letters, punctuation, space and digits.

    Args:
        prompt: Text to perturb. Falsy values (``""``, ``None``) are returned
            unchanged.
        noise_level: Noise intensity. Zero, negative or NaN values mean
            "no noise"; ``inf`` replaces every character.

    Returns:
        The perturbed string (same length as *prompt*), or *prompt* itself
        when no characters are to be replaced.
    """
    # `not noise_level > 0` also rejects NaN (every comparison with NaN is
    # False). Negative levels used to reach random.sample with a negative k
    # and raise ValueError.
    if not prompt or not noise_level > 0:
        return prompt
    percentage_noise = noise_level * 5
    # Clamp before int() so an infinite level means "every character" instead
    # of raising OverflowError; for finite values this equals the old
    # int-then-min computation.
    num_noise_chars = int(min(len(prompt) * (percentage_noise / 100), len(prompt)))
    if num_noise_chars == 0:
        return prompt
    noise_indices = random.sample(range(len(prompt)), num_noise_chars)
    noise_chars = string.ascii_letters + string.punctuation + ' ' + string.digits
    prompt_list = list(prompt)
    for index in noise_indices:
        prompt_list[index] = random.choice(noise_chars)
    return "".join(prompt_list)
def generate_image(inputs: str, is_negative: str, steps: int, cfg_scale: float, seed: int, noise_level: float, *, timeout: float = 120.0):
    """Translate *inputs* to English, perturb it, and fetch an image from Pollinations.

    Args:
        inputs: User prompt in any language; translated to English with
            GoogleTranslator (source language auto-detected).
        is_negative: Accepted for interface compatibility; not sent to the API.
        steps: Accepted for interface compatibility; not sent to the API.
        cfg_scale: Accepted for interface compatibility; not sent to the API.
        seed: Seed forwarded to the API. When None, a random seed in
            [0, 100000] is drawn.
        noise_level: Forwarded to add_random_noise to perturb the translated
            prompt.
        timeout: Seconds to wait for the Pollinations response (keyword-only).

    Returns:
        A decoded PIL Image, or None if translation, the HTTP request or image
        decoding failed (the error is printed).
    """
    try:
        translator_to_en = GoogleTranslator(source='auto', target='english')
        english_inputs = translator_to_en.translate(inputs)
        prompt_with_noise = add_random_noise(english_inputs, noise_level)
        # The prompt is a single path segment, so percent-encode everything,
        # including "/", "?", "#" and "%". add_random_noise can inject these
        # from string.punctuation, and unencoded they would split the path or
        # start a query string / fragment.
        request_url = f"{API_URL}/{quote(prompt_with_noise, safe='')}"
        params = {
            # Honor a caller-supplied seed; fall back to a random one.
            "seed": seed if seed is not None else random.randint(0, 100000),
            "width": 1024,
            "height": 1024,
            "nologo": "true",
            "model": "flux"  # Using a free model
        }
        # Without a timeout a stalled upstream would block this worker forever.
        response = requests.get(request_url, params=params, timeout=timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt payloads fail inside this try.
        image.load()
        return image
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error fetching image from Pollinations: {e}")
        print(f"Error Response: {e.response.text}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred in generate_image: {e}")
        return None
# --- Endpoint with safety check ---
@app.get("/send_inputs")
def send_inputs(
    inputs: str,
    noise_level: float,
    is_negative: str,
    steps: int = 20,
    cfg_scale: float = 4.5,
    seed: int = None
):
    """Screen a prompt, generate an image for it and return it base64-encoded.

    Args:
        inputs: User prompt (any language).
        noise_level: Prompt perturbation level forwarded to generate_image.
        is_negative: Forwarded to generate_image.
        steps: Forwarded to generate_image.
        cfg_scale: Forwarded to generate_image.
        seed: Optional seed forwarded to generate_image.

    Returns:
        ``{"result_base64": <base64-encoded JPEG bytes>}`` on success.

    Raises:
        HTTPException: 400 if the prompt fails the safety check; 500 if the
            backend produced no image or any other error occurred.
    """
    # Step 1: reject unsafe prompts before doing any translation/generation.
    if not is_prompt_safe(inputs):
        raise HTTPException(
            status_code=400,  # 400 Bad Request is appropriate here
            detail="Inappropriate prompt detected. Please try again with a different prompt."
        )
    # Step 2: the prompt is safe, so generate and encode the image.
    try:
        generated_image = generate_image(inputs, is_negative, steps, cfg_scale, seed, noise_level)
        if generated_image is None:
            raise HTTPException(status_code=500, detail="Failed to generate image from the backend service.")
        # JPEG cannot store alpha or palette images (e.g. RGBA, P); normalize
        # anything Pillow's JPEG writer doesn't accept to RGB so save() won't fail.
        if generated_image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            generated_image = generated_image.convert("RGB")
        image_bytes = io.BytesIO()
        generated_image.save(image_bytes, format="JPEG")
        image_base64 = base64.b64encode(image_bytes.getvalue()).decode("utf-8")
        return {"result_base64": image_base64}
    except HTTPException:
        # Propagate deliberate HTTP errors (with their specific detail)
        # instead of masking them as a generic server error below.
        raise
    except Exception as e:
        print(f"Error in send_inputs endpoint: {e}")
        raise HTTPException(status_code=500, detail="An unexpected server error occurred.") from e