File size: 6,149 Bytes
a4b70d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from __future__ import annotations
import json
import uuid
from ...typing import AsyncResult, Messages
from ...providers.response import ImageResponse, ImagePreview, JsonConversation, Reasoning
from ...requests import StreamSession
from ...image import use_aspect_ratio
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_media_prompt
from .DeepseekAI_JanusPro7b import get_zerogpu_token
from .raise_for_status import raise_for_status
class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
    """Image-generation provider backed by the BlackForestLabs FLUX.1-dev
    Gradio space on Hugging Face (ZeroGPU queue API)."""
    label = "BlackForestLabs Flux-1-Dev"
    url = "https://black-forest-labs-flux-1-dev.hf.space"
    space = "black-forest-labs/FLUX.1-dev"
    referer = f"{url}/?__theme=light"

    working = True

    default_model = 'black-forest-labs-flux-1-dev'
    default_image_model = default_model
    model_aliases = {"flux-dev": default_image_model, "flux": default_image_model}
    image_models = list(model_aliases.keys())
    models = image_models

    @classmethod
    def run(cls, method: str, session: StreamSession, conversation: JsonConversation, data: list = None):
        """Build a Gradio queue request for the space.

        ``method == "post"`` joins the generation queue with ``data`` as the
        fn arguments; any other value opens the SSE event stream for this
        session. Returns the unawaited StreamSession request context manager.
        """
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "x-zerogpu-token": conversation.zerogpu_token,
            "x-zerogpu-uuid": conversation.zerogpu_uuid,
            "referer": cls.referer,
        }
        if method == "post":
            return session.post(f"{cls.url}/gradio_api/queue/join?__theme=light", **{
                # Drop unset auth headers rather than sending literal "None" values.
                "headers": {k: v for k, v in headers.items() if v is not None},
                "json": {"data": data, "event_data": None, "fn_index": 2, "trigger_id": 4, "session_hash": conversation.session_hash}
            })
        return session.get(f"{cls.url}/gradio_api/queue/data?session_hash={conversation.session_hash}", **{
            "headers": {
                "accept": "text/event-stream",
                "content-type": "application/json",
                "referer": cls.referer,
            }
        })

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        prompt: str = None,
        proxy: str = None,
        aspect_ratio: str = "1:1",
        width: int = None,
        height: int = None,
        guidance_scale: float = 3.5,
        num_inference_steps: int = 28,
        seed: int = 0,
        randomize_seed: bool = True,
        cookies: dict = None,
        api_key: str = None,
        zerogpu_uuid: str = "[object Object]",
        **kwargs
    ) -> AsyncResult:
        """Generate an image from *messages*/*prompt* and stream results.

        Yields Reasoning status updates while the queue progresses,
        ImagePreview items for intermediate outputs, and a final
        ImageResponse on completion.

        Raises:
            ResponseError: when the space reports a generation error.
            RuntimeError: when the queue join fails or an SSE message
                cannot be parsed.
        """
        async with StreamSession(impersonate="chrome", proxy=proxy) as session:
            prompt = format_media_prompt(messages, prompt)
            # Resolve width/height from the aspect ratio when not given explicitly.
            size = use_aspect_ratio({"width": width, "height": height}, aspect_ratio)
            data = [prompt, seed, randomize_seed, size.get("width"), size.get("height"), guidance_scale, num_inference_steps]
            conversation = JsonConversation(zerogpu_token=api_key, zerogpu_uuid=zerogpu_uuid, session_hash=uuid.uuid4().hex)
            if conversation.zerogpu_token is None:
                # No API key supplied: fetch an anonymous ZeroGPU token/uuid pair.
                conversation.zerogpu_uuid, conversation.zerogpu_token = await get_zerogpu_token(cls.space, session, conversation, cookies)
            async with cls.run("post", session, conversation, data) as response:
                await raise_for_status(response)
                # Explicit check instead of `assert`: asserts are stripped under -O.
                if not (await response.json()).get("event_id"):
                    raise RuntimeError("Queue join response is missing event_id")
            async with cls.run("get", session, conversation) as event_response:
                await raise_for_status(event_response)
                async for chunk in event_response.iter_lines():
                    if not chunk.startswith(b"data: "):
                        continue
                    try:
                        json_data = json.loads(chunk[6:])
                        if json_data is None:
                            # Keep-alive/empty SSE payload.
                            continue
                        msg = json_data.get('msg')
                        if msg == 'log':
                            yield Reasoning(status=json_data["log"])
                        if msg == 'progress':
                            if 'progress_data' in json_data:
                                if json_data['progress_data']:
                                    progress = json_data['progress_data'][0]
                                    yield Reasoning(status=f"{progress['desc']} {progress['index']}/{progress['length']}")
                                else:
                                    yield Reasoning(status="Generating")
                        elif msg == 'process_generating':
                            for item in json_data['output']['data'][0]:
                                if isinstance(item, dict) and "url" in item:
                                    yield ImagePreview(item["url"], prompt)
                                elif isinstance(item, list) and len(item) > 2 and "url" in item[1]:
                                    # NOTE(review): the guard inspects item[1] but item[2] is
                                    # yielded — looks inconsistent; confirm against the space's
                                    # actual event payload before changing.
                                    yield ImagePreview(item[2], prompt)
                        elif msg == 'process_completed':
                            if 'output' in json_data and 'error' in json_data['output']:
                                # Strip the trailing HTML link the space appends to errors.
                                json_data['output']['error'] = json_data['output']['error'].split(" <a ")[0]
                                raise ResponseError(json_data['output']['error'])
                            if 'output' in json_data and 'data' in json_data['output']:
                                # Clear the status line before emitting the final image.
                                yield Reasoning(status="")
                                if len(json_data['output']['data']) > 0:
                                    yield ImageResponse(json_data['output']['data'][0]["url"], prompt)
                                break
                    except (json.JSONDecodeError, KeyError, TypeError) as e:
                        raise RuntimeError(f"Failed to parse message: {chunk.decode(errors='replace')}", e)
|