Spaces:
Build error
Build error
| import gradio as gr | |
| import replicate | |
| import os | |
| import random | |
| import openai | |
| import numpy as np | |
| from PIL import Image | |
| import requests | |
| import io | |
| import base64 | |
| import zipfile | |
# API credentials are read from the environment (on Hugging Face Spaces, add
# them as repository secrets) instead of being hard-coded in source.
# NOTE(security): the Replicate and OpenAI keys that were previously committed
# here are public and must be revoked and rotated.
# Initialize the Replicate client; with no explicit token it reads
# REPLICATE_API_TOKEN from the environment.
rep_client = replicate.Client()
# OpenAI API key for the module-level openai client (None if unset; calls
# will then fail at request time with an authentication error).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY
# Predefined prompts for the dropdown
predefined_prompts = [
    "Missing bolts on railway track",
    "Cracks on railway track",
    "Overgrown vegetation near railway track",
    "Broken railings on railway bridge",
    "Debris on railway track",
    "Damaged railway platform",
]
def ask_rail_defect_question(question, model_name='ft:gpt-3.5-turbo-0125:personal::99NsSAeQ'):
    """Send *question* to the chat model and return its answer text.

    A fixed system message frames the assistant as knowledgeable about rail
    defects; the question is sent as the single user message.

    Args:
        question: The user's question.
        model_name: Chat model identifier passed to the OpenAI API.

    Returns:
        The ``content`` of the message in the first returned choice.
    """
    # NOTE(review): openai.ChatCompletion is the pre-1.0 openai SDK interface;
    # confirm the pinned openai version still provides it.
    system_message = {
        "role": "system",
        "content": "The assistant is knowledgeable about rail defects and can answer questions related to them.",
    }
    user_message = {"role": "user", "content": question}
    completion = openai.ChatCompletion.create(
        model=model_name,
        messages=[system_message, user_message],
    )
    first_choice = completion.choices[0]
    return first_choice.message['content']
# Function to generate randomized prompt variations
def generate_variations(base_prompt, number_of_variations, rng=None):
    """Build randomized image-generation prompts from a base defect description.

    Each variation appends a randomly chosen defect size, location and weather
    condition to *base_prompt*. No language model is involved; the base prompt
    is used verbatim.

    Args:
        base_prompt: Defect description to start from.
        number_of_variations: How many prompts to produce. Floats (as delivered
            by ``gr.Number``) are truncated to ``int``; values <= 0 give ``[]``.
        rng: Optional object with a ``choice(seq)`` method, e.g. a seeded
            ``random.Random`` for reproducible output. Defaults to the
            module-level ``random`` functions.

    Returns:
        A list of prompt strings, one per requested variation.
    """
    locations = ["on the left side", "on the right side", "at the top", "at the bottom", "in the center"]
    sizes = ["small", "medium", "large", "tiny", "huge"]
    weather_conditions = ["under cold conditions", "during hot weather", "in dry weather", "in humid conditions", "under varying temperatures"]
    chooser = random if rng is None else rng
    # range() rejects floats, and gr.Number yields 1.0 rather than 1.
    count = int(number_of_variations)
    variations = []
    for _ in range(count):
        # Draw in a fixed order (location, size, weather) so seeded runs are stable.
        location = chooser.choice(locations)
        size = chooser.choice(sizes)
        weather = chooser.choice(weather_conditions)
        variations.append(f"{base_prompt}, with a {size} defect {location}, observed {weather}.")
    return variations
def image_to_data_url(image):
    """Return *image* encoded as a base64 PNG ``data:`` URL.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
# Function to inpaint images
def inpaint_defect(image, prompt, num_images=1):
    """Apply a text instruction to *image* with a Replicate image-editing model.

    The input keys (``input_image``, ``instruction_text``,
    ``image_guidance_scale``) suggest an InstructPix2Pix-style model -- TODO
    confirm against the model pinned by ``version`` below.

    Args:
        image: The uploaded picture, e.g. a NumPy array from ``gr.Image``, or
            an object with a PIL-style ``save`` method.
        prompt: Instruction describing the defect to add.
        num_images: How many predictions to run (one output each). Floats
            (as delivered by ``gr.Number``) are truncated to ``int``.

    Returns:
        One entry per requested image: a PIL image for each prediction that
        succeeded with output, otherwise ``None``.

    Raises:
        ValueError: if no image was supplied.
        requests.HTTPError: if downloading a generated image fails.
    """
    if image is None:
        raise ValueError("Please upload an image to inpaint.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The request payload is identical for every run, so build it once.
    model_input = {
        "input_image": image_to_data_url(image),
        "instruction_text": prompt,
        "scheduler": "K_EULER_ANCESTRAL",
        "num_outputs": 1,
        "guidance_scale": 7.5,
        "num_inference_steps": 100,
        "image_guidance_scale": 1.5,
    }
    images = []
    # range() rejects floats, and gr.Number yields 1.0 rather than 1.
    for _ in range(int(num_images)):
        prediction = rep_client.predictions.create(
            version="10e63b0e6361eb23a0374f4d9ee145824d9d09f7a31dcd70803193ebc7121430",
            input=model_input,
        )
        prediction.wait()
        # Guard against an empty output list as well as a failed status.
        if prediction.status == "succeeded" and prediction.output:
            images.append(_fetch_image(prediction.output[0]))
        else:
            images.append(None)
    return images


def _fetch_image(url, timeout=60):
    """Download *url* and decode the response body as a PIL image."""
    # A timeout keeps a stalled download from hanging the request forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
# Function to generate images from prompts
def generate_images(prompts):
    """Run the Replicate text-to-image model once per prompt.

    Returns one entry per prompt, in order:
    - the first output of a prediction that succeeded with output;
    - "Failed to generate image." when the prediction did not succeed or
      produced no output;
    - "Error: <message>" when any exception was raised for that prompt.
    """
    outputs = []
    for text in prompts:
        try:
            job = rep_client.predictions.create(
                version="ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
                input={"prompt": text, "scheduler": "K_EULER"},
            )
            job.wait()
            usable = job.status == "succeeded" and bool(job.output)
            outputs.append(job.output[0] if usable else "Failed to generate image.")
        except Exception as exc:
            # Best effort: record the failure for this prompt and keep going.
            outputs.append(f"Error: {str(exc)}")
    return outputs
def process_railway_defects(prompt, number_of_images):
    """Generate one image per randomized variation of *prompt*.

    Args:
        prompt: Base defect description (dropdown choice or free text).
        number_of_images: How many variations/images to produce. Floats (as
            delivered by ``gr.Number``) are coerced to ``int`` before they
            reach ``range()`` in the variation generator.

    Returns:
        The list returned by ``generate_images`` for the generated variations.

    Raises:
        ValueError: if *prompt* is ``None`` or blank.
    """
    # An unselected dropdown yields None; without this guard every prompt
    # would literally start with "None, with a ...".
    if prompt is None or not str(prompt).strip():
        raise ValueError("Please select or enter a defect description.")
    variations = generate_variations(prompt, int(number_of_images))
    return generate_images(variations)
def download_images_as_zip(images):
    """Pack images into an in-memory ZIP archive of PNG files.

    Args:
        images: Iterable of image-like entries. Each entry may be an object
            with a PIL-style ``save(fp, format=...)`` method, a NumPy array,
            a file path, or an ``(image, caption)`` pair (the shape a
            ``gr.Gallery`` input delivers -- TODO confirm for the pinned
            Gradio version). ``None`` entries, such as failed inpaints, are
            skipped. ``None`` or an empty iterable yields an empty archive.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 that contains
        ``image_1.png``, ``image_2.png``, ... in input order.
    """
    usable = []
    for item in images or []:
        # Gallery values arrive as (image, caption) pairs; keep the image.
        if isinstance(item, (tuple, list)):
            item = item[0] if item else None
        if item is not None:
            usable.append(item)

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zf:
        for index, item in enumerate(usable, start=1):
            img_buffer = io.BytesIO()
            _write_png(item, img_buffer)
            zf.writestr(f"image_{index}.png", img_buffer.getvalue())
    zip_buffer.seek(0)
    return zip_buffer


def _write_png(item, fp):
    """Encode one image-like *item* (path, array or PIL-style image) as PNG into *fp*."""
    if isinstance(item, (str, os.PathLike)):
        # Close the source file once it has been re-encoded.
        with Image.open(item) as opened:
            opened.save(fp, format="PNG")
        return
    if isinstance(item, np.ndarray):
        item = Image.fromarray(item)
    item.save(fp, format="PNG")
# UI creation
with gr.Blocks() as app:
    # gr.Tabs takes the id of the initially selected tab as its first
    # positional argument, not a title, so it is called without arguments.
    with gr.Tabs():
        with gr.Tab("Current Defects"):
            with gr.Row():
                # Pre-select a prompt so the handler never receives None.
                prompt_input = gr.Dropdown(
                    choices=predefined_prompts,
                    value=predefined_prompts[0],
                    label="Select a prompt",
                )
                # precision=0 makes gr.Number deliver an int; the default
                # (precision=None) delivers a float, which range() rejects.
                number_input_dropdown = gr.Number(
                    label="Number of images to generate",
                    value=1,
                    minimum=1,
                    maximum=10,
                    precision=0,
                )
            submit_button_dropdown = gr.Button("Generate")
            image_outputs_dropdown = gr.Gallery()

            def on_submit_click_dropdown(prompt, number_of_images):
                """Generate images for the selected predefined defect prompt."""
                images = process_railway_defects(prompt, number_of_images)
                return images

            submit_button_dropdown.click(
                fn=on_submit_click_dropdown,
                inputs=[prompt_input, number_input_dropdown],
                outputs=image_outputs_dropdown,
            )

        with gr.Tab("Custom Defect"):
            with gr.Row():
                custom_prompt_input = gr.Textbox(label="Custom Defect")
                number_input_custom = gr.Number(
                    label="Number of images to generate",
                    value=1,
                    minimum=1,
                    maximum=10,
                    precision=0,
                )
            submit_button_custom = gr.Button("Generate")
            image_outputs_custom = gr.Gallery()

            def on_submit_click_custom(custom_prompt, number_of_images):
                """Generate images for a free-text defect description."""
                images = process_railway_defects(custom_prompt, number_of_images)
                return images

            submit_button_custom.click(
                fn=on_submit_click_custom,
                inputs=[custom_prompt_input, number_input_custom],
                outputs=image_outputs_custom,
            )

        with gr.Tab("Inpaint Defect"):
            with gr.Row():
                image_input = gr.Image(label="Upload Image")
                inpaint_prompt_input = gr.Textbox(label="Defect Description")
                number_input_inpaint = gr.Number(
                    label="Number of images to generate",
                    value=1,
                    minimum=1,
                    maximum=10,
                    precision=0,
                )
            submit_button_inpaint = gr.Button("Inpaint Defect")
            inpainted_image_output = gr.Gallery()
            download_button = gr.Button("Download Images as Zip")
            download_file_output = gr.File(label="Download Zip")

            def on_submit_click_inpaint(image, inpaint_prompt, number_of_images):
                """Inpaint the uploaded image according to the description."""
                inpainted_images = inpaint_defect(image, inpaint_prompt, num_images=number_of_images)
                return inpainted_images

            def on_download_click(images):
                """Zip the gallery's current images and return the archive's path.

                gr.File outputs expect a file path rather than an in-memory
                buffer, so the archive is written to a temporary file.
                delete=False keeps the file on disk so Gradio can serve it.
                """
                import tempfile

                if not images:
                    return None  # nothing to download yet; clears the File output
                zip_buffer = download_images_as_zip(images)
                with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
                    tmp.write(zip_buffer.getvalue())
                return tmp.name

            submit_button_inpaint.click(
                fn=on_submit_click_inpaint,
                inputs=[image_input, inpaint_prompt_input, number_input_inpaint],
                outputs=inpainted_image_output,
            )
            download_button.click(
                fn=on_download_click,
                inputs=inpainted_image_output,
                outputs=download_file_output,
            )

if __name__ == "__main__":
    app.launch()