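# Hosting-page header captured together with the file (not part of the program):
#   Uploader: Rahatara
#   Commit: "Rename app.py to app2.py" (0803f4e, verified)
#   Page links: raw / history / blame
#   File size: 7.63 kB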
import gradio as gr
import replicate
import os
import random
import openai
import numpy as np
from PIL import Image
import requests
import io
import base64
import zipfile
# Set API tokens
# SECURITY: credentials must never be committed with the source. Both tokens
# are read from the process environment (for example, deployment secrets).
# Any token that was previously embedded in this file should be treated as
# leaked and revoked with the provider.
# Initialize the Replicate client; it picks up REPLICATE_API_TOKEN from the
# environment, so no token has to be written here.
rep_client = replicate.Client()
# Set your OpenAI API key via the OPENAI_API_KEY environment variable.
# The value is None when the variable is unset; calls that need it will then
# fail at request time rather than at import time.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY
# Defect descriptions offered as ready-made choices in the prompt dropdown
predefined_prompts = [
    "Missing bolts on railway track",
    "Cracks on railway track",
    "Overgrown vegetation near railway track",
    "Broken railings on railway bridge",
    "Debris on railway track",
    "Damaged railway platform",
]
def ask_rail_defect_question(question, model_name='ft:gpt-3.5-turbo-0125:personal::99NsSAeQ'):
    """Send one question to the chat model and return the reply text.

    Args:
        question: Text sent as the user message.
        model_name: Model identifier passed to the chat-completion call.

    Returns:
        The ``content`` field of the first returned choice's message.
    """
    system_message = {
        "role": "system",
        "content": "The assistant is knowledgeable about rail defects and can answer questions related to them.",
    }
    user_message = {"role": "user", "content": question}
    completion = openai.ChatCompletion.create(
        model=model_name,
        messages=[system_message, user_message],
    )
    first_choice = completion.choices[0]
    return first_choice.message['content']
# Function to generate randomized prompt variations around a base description
def generate_variations(base_prompt, number_of_variations):
    """Build prompt variations by appending random defect attributes.

    Each variation is the base prompt followed by a randomly chosen defect
    size, location and weather condition. No language model is involved;
    the base prompt is used verbatim.

    Args:
        base_prompt: Text placed at the start of every variation.
        number_of_variations: How many variations to build. Floats are
            truncated to an integer and ``None`` is treated as zero.

    Returns:
        A list of prompt strings; empty when the count is zero or negative.
    """
    locations = ["on the left side", "on the right side", "at the top", "at the bottom", "in the center"]
    sizes = ["small", "medium", "large", "tiny", "huge"]
    weather_conditions = ["under cold conditions", "during hot weather", "in dry weather", "in humid conditions", "under varying temperatures"]
    # Number widgets can hand over a float (3.0) or None, and range() only
    # accepts integers, so normalize the count first.
    count = int(number_of_variations or 0)
    variations = []
    for _ in range(count):
        # Draw order (location, size, weather) is kept stable so that a
        # seeded random generator reproduces the same prompts.
        location = random.choice(locations)
        size = random.choice(sizes)
        weather = random.choice(weather_conditions)
        variations.append(f"{base_prompt}, with a {size} defect {location}, observed {weather}.")
    return variations
def image_to_data_url(image):
    """Encode an image as a base64 PNG ``data:`` URL.

    Args:
        image: Object whose ``save(stream, format="PNG")`` writes the image
            bytes into the given stream.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    png_stream = io.BytesIO()
    image.save(png_stream, format="PNG")
    encoded = base64.b64encode(png_stream.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
# Function to inpaint images
def inpaint_defect(image, prompt, num_images=1):
    """Request instruction-guided edits of an image from the Replicate model.

    Args:
        image: Source picture as a PIL image or a NumPy array. ``None``
            (nothing uploaded) yields an empty result.
        prompt: Instruction text describing the defect to paint in.
        num_images: Number of predictions to request. Floats are truncated
            to an integer and ``None`` is treated as zero.

    Returns:
        A list with one entry per requested prediction, in order: a PIL
        image on success, or ``None`` when the prediction did not succeed,
        produced no output, or its result could not be downloaded.
    """
    if image is None:
        return []
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_data_url = image_to_data_url(image)
    # The request payload is identical for every prediction, so build it once.
    payload = {
        "input_image": image_data_url,
        "instruction_text": prompt,
        "scheduler": "K_EULER_ANCESTRAL",
        "num_outputs": 1,
        "guidance_scale": 7.5,
        "num_inference_steps": 100,
        "image_guidance_scale": 1.5,
    }
    images = []
    # Number widgets can hand over a float; range() needs an integer.
    for _ in range(int(num_images or 0)):
        prediction = rep_client.predictions.create(
            version="10e63b0e6361eb23a0374f4d9ee145824d9d09f7a31dcd70803193ebc7121430",
            input=dict(payload),
        )
        prediction.wait()
        if prediction.status != "succeeded" or not prediction.output:
            images.append(None)
            continue
        try:
            # A timeout keeps a stalled download from hanging the request,
            # and raise_for_status stops an error page being parsed as an image.
            response = requests.get(prediction.output[0], timeout=60)
            response.raise_for_status()
        except requests.RequestException:
            images.append(None)
            continue
        images.append(Image.open(io.BytesIO(response.content)))
    return images
# Function to generate images from prompts
def generate_images(prompts):
    """Run one text-to-image prediction per prompt and collect the outcomes.

    Args:
        prompts: Iterable of prompt strings.

    Returns:
        A list with one entry per prompt, in order. On success the entry is
        the first item of the prediction output. Otherwise it is the string
        ``"Failed to generate image."`` (prediction not succeeded or empty
        output) or ``"Error: <message>"`` (an exception was raised).
        NOTE(review): failures are reported in-band as strings, so callers
        that display the list as images may want to filter them.
    """
    model_version = "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4"
    results = []
    for text in prompts:
        try:
            job = rep_client.predictions.create(
                version=model_version,
                input={"prompt": text, "scheduler": "K_EULER"},
            )
            job.wait()
            succeeded = job.status == "succeeded" and job.output
            outcome = job.output[0] if succeeded else "Failed to generate image."
        except Exception as exc:
            outcome = f"Error: {str(exc)}"
        results.append(outcome)
    return results
def process_railway_defects(prompt, number_of_images):
    """Expand a prompt into variations and generate one image per variation.

    Args:
        prompt: Base defect description.
        number_of_images: Number of variations (and therefore images) to request.

    Returns:
        The list returned by ``generate_images`` for the generated variations.
    """
    return generate_images(generate_variations(prompt, number_of_images))
def download_images_as_zip(images):
    """Pack images into an in-memory ZIP archive as PNG files.

    Args:
        images: Sequence of objects exposing ``save(stream, format='PNG')``.
            ``None`` entries (failed generations) are skipped, and a ``None``
            sequence is treated as empty.

    Returns:
        An ``io.BytesIO`` positioned at the start of the archive. Members
        are named ``image_<n>.png`` where ``n`` is the 1-based position of
        the image in ``images``, so skipped entries leave gaps.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w') as zf:
        for position, img in enumerate(images or [], start=1):
            if img is None:
                # Placeholder for an image that could not be produced.
                continue
            img_buffer = io.BytesIO()
            img.save(img_buffer, format='PNG')
            zf.writestr(f'image_{position}.png', img_buffer.getvalue())
    zip_buffer.seek(0)
    return zip_buffer
# UI creation
# NOTE(review): the layout nesting below (which widgets share a row) was
# reconstructed; confirm it matches the intended design.


def _renderable_images(images):
    """Drop entries that a gallery cannot display.

    The generation helpers report failures in-band as ``None`` or as a
    message string, so anything that is ``None`` or a string that is not an
    http(s) URL is removed before the list is shown.
    """
    shown = []
    for item in images or []:
        if item is None:
            continue
        if isinstance(item, str) and not item.startswith(("http://", "https://")):
            continue
        shown.append(item)
    return shown


with gr.Blocks() as app:
    # gr.Tabs takes no label; the tab titles come from the gr.Tab children.
    with gr.Tabs():
        with gr.Tab("Current Defects"):
            with gr.Row():
                prompt_input = gr.Dropdown(choices=predefined_prompts, label="Select a prompt")
                # precision=0 makes the widget deliver an integer count.
                number_input_dropdown = gr.Number(label="Number of images to generate", value=1, minimum=1, maximum=10, precision=0)
                submit_button_dropdown = gr.Button("Generate")
            image_outputs_dropdown = gr.Gallery()

            def on_submit_click_dropdown(prompt, number_of_images):
                """Generate images for the selected predefined prompt."""
                if not prompt:
                    # Nothing selected yet: show an empty gallery.
                    return []
                return _renderable_images(process_railway_defects(prompt, number_of_images))

            submit_button_dropdown.click(
                fn=on_submit_click_dropdown,
                inputs=[prompt_input, number_input_dropdown],
                outputs=image_outputs_dropdown,
            )
        with gr.Tab("Custom Defect"):
            with gr.Row():
                custom_prompt_input = gr.Textbox(label="Custom Defect")
                number_input_custom = gr.Number(label="Number of images to generate", value=1, minimum=1, maximum=10, precision=0)
                submit_button_custom = gr.Button("Generate")
            image_outputs_custom = gr.Gallery()

            def on_submit_click_custom(custom_prompt, number_of_images):
                """Generate images for a free-text defect description."""
                if not custom_prompt:
                    return []
                return _renderable_images(process_railway_defects(custom_prompt, number_of_images))

            submit_button_custom.click(
                fn=on_submit_click_custom,
                inputs=[custom_prompt_input, number_input_custom],
                outputs=image_outputs_custom,
            )
        with gr.Tab("Inpaint Defect"):
            with gr.Row():
                image_input = gr.Image(label="Upload Image")
                inpaint_prompt_input = gr.Textbox(label="Defect Description")
                number_input_inpaint = gr.Number(label="Number of images to generate", value=1, minimum=1, maximum=10, precision=0)
                submit_button_inpaint = gr.Button("Inpaint Defect")
            inpainted_image_output = gr.Gallery()
            download_button = gr.Button("Download Images as Zip")
            zip_file_output = gr.File(label="Download Zip")
            # The generated image objects are kept in session state so the
            # download handler receives them directly instead of whatever
            # representation the gallery component passes back as an input.
            inpainted_images_state = gr.State([])

            def on_submit_click_inpaint(image, inpaint_prompt, number_of_images):
                """Inpaint the uploaded image; return images for gallery and state."""
                if image is None:
                    return [], []
                inpainted_images = _renderable_images(
                    inpaint_defect(image, inpaint_prompt, num_images=number_of_images)
                )
                return inpainted_images, inpainted_images

            def on_download_click(images):
                """Write the stored images to a ZIP file and return its path."""
                import tempfile

                if not images:
                    return None
                zip_buffer = download_images_as_zip(images)
                # The file component is given a path on disk, so the
                # in-memory archive is persisted to a temporary file first.
                with tempfile.NamedTemporaryFile(prefix="inpainted_", suffix=".zip", delete=False) as archive:
                    archive.write(zip_buffer.getvalue())
                return archive.name

            submit_button_inpaint.click(
                fn=on_submit_click_inpaint,
                inputs=[image_input, inpaint_prompt_input, number_input_inpaint],
                outputs=[inpainted_image_output, inpainted_images_state],
            )
            download_button.click(
                fn=on_download_click,
                inputs=inpainted_images_state,
                outputs=zip_file_output,
            )

if __name__ == "__main__":
    app.launch()