Spaces: Runtime error
| import gradio as gr | |
| import requests | |
| import io | |
| from PIL import Image | |
| import json | |
| import os | |
| import shutil | |
| import logging | |
| import math | |
| from tqdm import tqdm | |
| import time | |
| from diffusers import DiffusionPipeline | |
# Negative prompt sent when the user leaves the "Negative Prompt" box blank.
_DEFAULT_NEG_PROMPT = "bad art, ugly, watermark, deformed"
# Seconds to wait for a single HTTP request before treating it as failed.
_REQUEST_TIMEOUT_S = 60
# Max number of 503 ("model loading") retries, spaced one second apart.
_MAX_LOADING_RETRIES = 300
# Max number of 500 retries before the error is surfaced to the user.
_MAX_SERVER_ERROR_RETRIES = 5


def run_lora(lora, prompt, neg_prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate an image for *prompt* using the Inference API model *lora*.

    Args:
        lora: Model id appended to the Inference API URL
            (``https://api-inference.huggingface.co/models/<lora>``).
        prompt: Text sent as the request's ``inputs``.
        neg_prompt: Negative prompt sent under ``parameters.negative_prompt``;
            when blank/None, ``_DEFAULT_NEG_PROMPT`` is used instead.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets it mirror
            the tqdm bar advanced while the API answers 503.

    Returns:
        The PIL image decoded from the body of a 200 response.

    Raises:
        gr.Error: if no model id is given, the HTTP request itself fails,
            the API keeps answering 503 past ``_MAX_LOADING_RETRIES``, answers
            500 more than ``_MAX_SERVER_ERROR_RETRIES`` times, or returns any
            other non-200 status.
    """
    print(f"Inside run_lora, lora: {lora}, prompt: {prompt}, neg_prompt: {neg_prompt}")

    lora = (lora or "").strip()
    if not lora:
        # Without a model id the URL would point at ".../models/" and fail.
        raise gr.Error("Please enter a LoRA model card (e.g. 'user/model').")

    api_url = f"https://api-inference.huggingface.co/models/{lora}"
    # Honour the user's negative prompt (it used to be ignored entirely).
    negative_prompt = (neg_prompt or "").strip() or _DEFAULT_NEG_PROMPT
    payload = {
        "inputs": f"{prompt}",
        "parameters": {"negative_prompt": negative_prompt},
    }
    print(f"API Request: {api_url}")
    print(f"API Payload: {payload}")

    loading_retries = 0
    error_count = 0
    # Context manager guarantees the tqdm bar is closed on return or raise.
    with tqdm(total=None, desc="Loading model") as pbar:
        while True:
            try:
                response = requests.post(
                    api_url, json=payload, timeout=_REQUEST_TIMEOUT_S
                )
            except requests.exceptions.RequestException as err:
                logging.error(f"API request failed: {err}")
                raise gr.Error("API Error: Unable to fetch the image.") from err

            if response.status_code == 200:
                return Image.open(io.BytesIO(response.content))

            if (
                response.status_code == 503
                and loading_retries < _MAX_LOADING_RETRIES
            ):
                # 503 is returned while the model cold-boots; poll again shortly.
                # Capped so a model that never loads cannot hang the request.
                loading_retries += 1
                time.sleep(1)
                pbar.update(1)
            elif (
                response.status_code == 500
                and error_count < _MAX_SERVER_ERROR_RETRIES
            ):
                print(response.content)
                error_count += 1
                time.sleep(1)
            else:
                logging.error(f"API Error: {response.status_code}")
                raise gr.Error("API Error: Unable to fetch the image.")
# The three textboxes, in order, feed run_lora(lora, prompt, neg_prompt).
_lora_inputs = [
    gr.Textbox(
        label="LoRA model card",
        show_label=False,
        lines=1,
        max_lines=1,
        placeholder="Type the LoRA model card here.",
    ),
    gr.Textbox(
        label="Prompt",
        show_label=False,
        placeholder="Type a prompt after selecting a LoRA.",
    ),
    gr.Textbox(
        label="Negative Prompt",
        show_label=False,
        placeholder="Type negative prompt here.",
    ),
]

# run_lora returns a PIL image, rendered by the "image" output component.
app = gr.Interface(fn=run_lora, inputs=_lora_inputs, outputs="image")
app.launch()