File size: 2,296 Bytes
b36d8fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4a0eee
b36d8fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import requests
import io
from PIL import Image
import json
import os
import shutil
import logging
import math 
from tqdm import tqdm
import time
from diffusers import DiffusionPipeline


def run_lora(lora, prompt, neg_prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate an image from a LoRA model via the Hugging Face Inference API.

    Polls the serverless inference endpoint for the given model card,
    waiting out cold boots (HTTP 503) and retrying transient server
    errors (HTTP 500) up to 5 times.

    Parameters
    ----------
    lora : str
        Hugging Face model card id, e.g. "user/my-lora".
    prompt : str
        Positive text prompt.
    neg_prompt : str
        Negative text prompt; falls back to a generic default when empty.
    progress : gr.Progress
        Gradio progress tracker (tracks the tqdm bar automatically).

    Returns
    -------
    PIL.Image.Image
        The generated image.

    Raises
    ------
    gr.Error
        On any unrecoverable API error.
    """
    print(f"Inside run_lora, lora: {lora}, prompt: {prompt}, neg_prompt: {neg_prompt}")

    api_url = f"https://api-inference.huggingface.co/models/{lora}"
    # Bug fix: neg_prompt was previously ignored in favor of a hard-coded
    # string, making the UI's Negative Prompt box a no-op. Use the user's
    # value, keeping the old string as the default for empty input.
    payload = {
        "inputs": f"{prompt}",
        "parameters": {
            "negative_prompt": neg_prompt or "bad art, ugly, watermark, deformed",
        },
    }

    # Add a print statement to display the API request
    print(f"API Request: {api_url}")
    print(f"API Payload: {payload}")

    error_count = 0
    pbar = tqdm(total=None, desc="Loading model")
    try:
        while True:
            response = requests.post(api_url, json=payload)
            if response.status_code == 200:
                return Image.open(io.BytesIO(response.content))
            elif response.status_code == 503:
                # 503 is triggered when the model is doing a cold boot. It also
                # gives a load-time estimate, but it is not super precise, so
                # just poll once per second and tick the progress bar.
                time.sleep(1)
                pbar.update(1)
            elif response.status_code == 500 and error_count < 5:
                # Transient server error: log the body and retry a few times.
                print(response.content)
                time.sleep(1)
                error_count += 1
            else:
                logging.error(f"API Error: {response.status_code}")
                raise gr.Error("API Error: Unable to fetch the image.")  # Raise a Gradio error here
    finally:
        # Bug fix: the progress bar was never closed on any exit path.
        pbar.close()


# Build the Gradio UI: three text inputs feeding run_lora, one image output.
lora_box = gr.Textbox(
    label="LoRA model card",
    show_label=False,
    lines=1,
    max_lines=1,
    placeholder="Type the LoRA model card here.",
)
prompt_box = gr.Textbox(
    label="Prompt",
    show_label=False,
    placeholder="Type a prompt after selecting a LoRA.",
)
neg_prompt_box = gr.Textbox(
    label="Negative Prompt",
    show_label=False,
    placeholder="Type negative prompt here.",
)

app = gr.Interface(
    fn=run_lora,
    inputs=[lora_box, prompt_box, neg_prompt_box],
    outputs="image",
)

app.launch()