File size: 4,374 Bytes
9c5c12c
 
 
 
 
10cca52
 
 
 
 
5355110
9c5c12c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6a57ba
9c5c12c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6a57ba
 
 
 
 
 
 
 
 
 
9c5c12c
 
 
a6a57ba
 
 
9c5c12c
 
 
a6a57ba
9c5c12c
 
 
a6a57ba
 
 
 
 
 
10cca52
a6a57ba
 
5355110
9c5c12c
5355110
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import replicate
import os
import random
import openai
import numpy as np
from PIL import Image
import requests
import io
import base64
import zipfile

# API credentials are read from the environment instead of being hard-coded.
# SECURITY: the Replicate and OpenAI keys that were previously committed in this
# file must be treated as leaked and revoked.
#
# replicate.Client() without an explicit api_token reads REPLICATE_API_TOKEN
# from the environment.
rep_client = replicate.Client()

# Key for the module-level (pre-1.0) `openai` client; None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY

# Built-in defect descriptions offered in the "Generate Images" dropdown.
predefined_prompts = [
    "Missing bolts on railway track",
    "Cracks on railway track",
    "Overgrown vegetation near railway track",
    "Broken railings on railway bridge",
    "Debris on railway track",
    "Damaged railway platform",
]

def ask_rail_defect_question(question, model_name='ft:gpt-3.5-turbo-0125:personal::99NsSAeQ'):
    """Ask the rail-defect chat model a question and return its text reply.

    Args:
        question: Free-text message, forwarded verbatim as the ``user`` turn.
        model_name: Model identifier passed to ``openai.ChatCompletion.create``;
            defaults to a fine-tuned gpt-3.5-turbo model.

    Returns:
        The ``content`` of the first choice's message in the completion.
    """
    system_turn = {
        "role": "system",
        "content": "The assistant is knowledgeable about rail defects and can answer questions related to them.",
    }
    user_turn = {"role": "user", "content": question}

    completion = openai.ChatCompletion.create(
        model=model_name,
        messages=[system_turn, user_turn],
    )
    first_choice = completion.choices[0]
    return first_choice.message['content']

def generate_variations(base_prompt, number_of_variations, rng=None):
    """Build randomized variants of *base_prompt* for image generation.

    Each variant appends a randomly chosen defect size, location and weather
    condition to the base prompt.

    Args:
        base_prompt: Defect description to expand. ``None`` or a blank string
            (e.g. an unselected dropdown or empty textbox) yields no variants.
        number_of_variations: How many variants to build. Floats (which a
            ``gr.Number`` delivers unless ``precision=0``) are truncated to
            ``int``; ``None``, zero or negative values yield an empty list.
        rng: Optional object with a ``choice`` method, e.g.
            ``random.Random(seed)``, for reproducible output. Defaults to the
            global ``random`` module.

    Returns:
        List of strings of the form
        ``"<base>, with a <size> defect <location>, observed <weather>."``.
    """
    if not number_of_variations or base_prompt is None or not str(base_prompt).strip():
        return []

    chooser = random if rng is None else rng
    locations = ["on the left side", "on the right side", "at the top", "at the bottom", "in the center"]
    sizes = ["small", "medium", "large", "tiny", "huge"]
    weather_conditions = ["under cold conditions", "during hot weather", "in dry weather", "in humid conditions", "under varying temperatures"]

    variations = []
    # int(): range() rejects floats such as 3.0 coming from the UI.
    for _ in range(int(number_of_variations)):
        # Draw order (location, size, weather) is kept stable so seeded runs are repeatable.
        location = chooser.choice(locations)
        size = chooser.choice(sizes)
        weather = chooser.choice(weather_conditions)
        variations.append(f"{base_prompt}, with a {size} defect {location}, observed {weather}.")
    return variations

def generate_images(prompts):
    """Run each prompt through the Replicate image model and collect the results.

    Args:
        prompts: Iterable of text prompts; each is submitted as its own prediction
            using the K_EULER scheduler.

    Returns:
        List holding the first output item of every prediction that succeeded,
        in prompt order. Failed predictions and raised errors are logged and
        skipped. They are not returned as placeholder strings, because this list
        is displayed in a ``gr.Gallery``. That component treats every entry as an
        image path/URL, so a message such as "Failed to generate image." is not a
        valid image there.
    """
    import logging

    logger = logging.getLogger(__name__)
    images = []
    for prompt in prompts:
        try:
            prediction = rep_client.predictions.create(
                version="ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
                input={"prompt": prompt, "scheduler": "K_EULER"},
            )
            prediction.wait()
            if prediction.status == "succeeded" and prediction.output:
                images.append(prediction.output[0])
            else:
                logger.warning(
                    "Image prediction for prompt %r ended with status %r",
                    prompt,
                    prediction.status,
                )
        except Exception:
            # Broad on purpose: one network/API failure must not abort the remaining prompts.
            logger.exception("Image generation failed for prompt %r", prompt)
    return images

def _generate_gallery(prompt, num):
    """Expand *prompt* into *num* randomized variants and render each one.

    Shared by both generation tabs so their behavior stays identical.
    """
    return generate_images(generate_variations(prompt, num))


# UI creation
with gr.Blocks() as app:
    # gr.Tabs' first positional parameter is `selected` (a tab id), not a title,
    # so the former "Prompt Input" argument matched no tab; it is dropped.
    with gr.Tabs():
        with gr.Tab("Generate Images"):
            prompt_input = gr.Dropdown(choices=predefined_prompts, label="Select a defect prompt")
            # precision=0 makes Gradio deliver an int. The default (None) delivers a
            # float, which range() in generate_variations rejects.
            number_input = gr.Number(label="Number of images", value=1, minimum=1, maximum=10, precision=0)
            generate_button = gr.Button("Generate")
            gallery = gr.Gallery(label="Generated Images")

            generate_button.click(
                fn=_generate_gallery,
                inputs=[prompt_input, number_input],
                outputs=gallery,
            )

        with gr.Tab("Custom Defect"):
            custom_prompt_input = gr.Textbox(label="Custom Defect")
            number_input_custom = gr.Number(label="Number of images to generate", value=1, minimum=1, maximum=10, precision=0)
            submit_button_custom = gr.Button("Generate")
            image_outputs_custom = gr.Gallery()

            submit_button_custom.click(
                fn=_generate_gallery,
                inputs=[custom_prompt_input, number_input_custom],
                outputs=image_outputs_custom,
            )

    # The "feedback" text is sent to the rail-defect chat model and its reply is shown.
    feedback_input = gr.Textbox(label="Enter your feedback", placeholder="Write your feedback here...")
    feedback_button = gr.Button("Submit Feedback")
    feedback_result = gr.Textbox(label="System Response", interactive=False)
    refresh_button = gr.Button("Refresh Page")

    feedback_button.click(ask_rail_defect_question, inputs=feedback_input, outputs=feedback_result)
    # gr.update() has no page-reload option (the former reload_browser=True was a
    # no-op), so the browser is reloaded client-side with a JS-only handler.
    refresh_button.click(None, js="() => window.location.reload()")

if __name__ == "__main__":
    app.launch()