File size: 6,410 Bytes
c95e677 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | import gradio as gr
import subprocess
import time
import requests
from openai import OpenAI
from huggingface_hub import login, snapshot_download
import os
import stat
import tarfile
import io
# UI strings for the Gradio app.
TITLE = "Zero-shot Anime Knowledge Optimizer"
DESCRIPTION = """
"""

# Authenticate with the Hugging Face Hub. The token must be supplied via the
# HF_TOKEN environment variable (e.g. a Space secret); fail fast otherwise
# because the snapshot download below requires it.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    raise ValueError("environment variable HF_TOKEN not found.")

# Download the dataset repo that ships both the GGUF model and a prebuilt
# llama.cpp server binary archive for linux x64.
repo_id = "Johnny-Z/ZAKO-0.6B"
repo_dir = snapshot_download(repo_id, repo_type='dataset')
tar_path = os.path.join(repo_dir, "llama-b7972-bin-ubuntu-x64.tar.gz")
current_dir = os.path.dirname(os.path.abspath(__file__))

# Unpack the llama.cpp binaries next to this script. `filter="data"`
# (Python 3.12+, backported to some 3.8+ patch releases) sanitizes archive
# members against path traversal; fall back to a plain extractall on
# interpreters whose tarfile does not accept the keyword.
with tarfile.open(tar_path, mode="r:gz") as tar:
    try:
        tar.extractall(path=current_dir, filter="data")
    except TypeError:
        tar.extractall(path=current_dir)
def _find_llama_server(base_dir: str) -> str:
for root, _, files in os.walk(base_dir):
if "llama-server" in files:
return os.path.join(root, "llama-server")
raise FileNotFoundError(f"未找到 llama-server,可执行文件不在 {base_dir} 及其子目录中")
def get_predicted_tokens_seconds() -> str:
    """Read the live generation throughput from the local llama.cpp server.

    Scrapes the Prometheus-style ``/metrics`` endpoint on port 8188 and
    returns the value of the ``llamacpp:predicted_tokens_seconds`` metric as
    a string. Returns "N/A" when the server is unreachable, replies with an
    error status, or the metric line is missing/malformed.
    """
    metric_name = "llamacpp:predicted_tokens_seconds"
    try:
        response = requests.get("http://localhost:8188/metrics", timeout=2)
        response.raise_for_status()
    except requests.RequestException:
        return "N/A"
    for row in response.text.splitlines():
        if not row.startswith(metric_name):
            continue
        fields = row.split()
        # Prometheus text format: "<name> <value>"; take the trailing value.
        if len(fields) >= 2:
            return fields[-1]
    return "N/A"
# Resolve the extracted server binary and the quantized GGUF model file.
PATH_TO_SERVER_BINARY = _find_llama_server(current_dir)
PATH_TO_MODEL = os.path.join(repo_dir, "ZAKO-0.6B-Q4KM.gguf")

# Files restored from the tarball may lack the executable bit; add it for
# user, group, and others so subprocess.Popen can run the server.
st = os.stat(PATH_TO_SERVER_BINARY)
os.chmod(PATH_TO_SERVER_BINARY, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
def wait_for_server(url: str, timeout_s: int = 180, interval_s: float = 0.5, process: subprocess.Popen | None = None) -> None:
    """Block until *url* answers HTTP 200 or the startup is deemed failed.

    Polls *url* every *interval_s* seconds for up to *timeout_s* seconds.
    When *process* is given, its liveness is checked on every iteration so a
    crashed server is reported immediately instead of waiting out the timeout.

    Raises:
        RuntimeError: the server process exited early (its captured stderr
            is included in the message).
        TimeoutError: the deadline passed without a successful health check.
    """
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        # A dead child can never become healthy — surface its stderr now.
        if process is not None and process.poll() is not None:
            captured = ""
            if process.stderr:
                captured = process.stderr.read().decode("utf-8", errors="ignore")
            raise RuntimeError(f"本地推理引擎启动失败,退出码={process.returncode}\n{captured}")
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return
        except requests.RequestException:
            # Server not accepting connections yet; keep polling.
            pass
        time.sleep(interval_s)
    raise TimeoutError("本地推理引擎启动超时")
# Launch llama-server on port 8188 with the Prometheus /metrics endpoint
# enabled. stdout is discarded; stderr is piped so wait_for_server can
# include it in the error message if startup fails.
server_process = subprocess.Popen(
    [
        PATH_TO_SERVER_BINARY,
        "-m", PATH_TO_MODEL,
        "--ctx-size", "1280",
        "--port", "8188",
        "--metrics"
    ],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.PIPE
)
print("正在启动本地推理引擎...")
wait_for_server("http://localhost:8188/health", process=server_process)

# OpenAI-compatible client pointed at the local llama.cpp server. The API
# key is a placeholder: the local server performs no authentication.
client = OpenAI(
    base_url="http://localhost:8188/v1",
    api_key="sk-no-key-required"
)
def chat(question, tags, preference_level):
    """Stream a rewritten image prompt from the local ZAKO model.

    Builds a single-user-turn prompt from the optional free-text *question*,
    optional *tags*, and the *preference_level* quality setting, then streams
    the completion. Yields ``(accumulated_text, predicted_tokens_seconds)``
    tuples so the UI can display partial output alongside the live
    throughput metric.
    """
    prompt = f"""
# Role
Act as an image prompt writer. Your goal is to transform inputs into **objective, physical descriptions**. You must convert abstract concepts into concrete scenes, specifying composition, lighting, and textures. Any text to be rendered must be enclosed in double quotes `""` with its typography described. **Strictly avoid** subjective adjectives or quality tags (e.g. "8K", "Masterpiece", "Best Quality"). Output **only** the final visual description.
# User Input
Prompt Quality: {preference_level}
"""
    # Only include the optional sections that the user actually filled in.
    if tags.strip():
        prompt += f"\nTags: {tags}"
    if question.strip():
        prompt += f"\nQuestion: {question}"

    stream = client.chat.completions.create(
        model="ZAKO",
        messages=[{"role": "user", "content": prompt}],
        top_p=0.8,
        temperature=0.8,
        stream=True,
    )

    accumulated = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            accumulated += delta
        yield accumulated, get_predicted_tokens_seconds()
def main():
    """Assemble and launch the Gradio UI.

    Left panel: description/tags inputs, quality dropdown, submit/stop/clear
    controls. Right panel: streamed model output plus the throughput metric.
    The Stop button cancels the in-flight streaming event.
    """
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
        with gr.Row():
            # Input panel.
            with gr.Column(variant="panel"):
                submit_btn = gr.Button(value="Submit", variant="primary", size="lg")
                stop_btn = gr.Button(value="Stop", variant="secondary", size="lg")
                with gr.Row():
                    description_box = gr.Textbox(
                        label="Simple Description",
                        value="",
                        lines=4,
                    )
                with gr.Row():
                    tags_box = gr.Textbox(
                        label="Tags",
                        value="",
                        lines=2,
                    )
                with gr.Row():
                    quality_dropdown = gr.Dropdown(
                        choices=["very high", "high", "normal"],
                        value="very high",
                        label="Prompt Quality",
                    )
                with gr.Row():
                    clear_btn = gr.ClearButton(
                        components=[],
                        variant="secondary",
                        size="lg",
                    )
                gr.Markdown(value=DESCRIPTION)
            # Output panel.
            with gr.Column(variant="panel"):
                output_box = gr.Textbox(label="Output", lines=20)
                metrics_box = gr.Textbox(label="predicted_tokens_seconds", lines=1, interactive=False)

        # Clear wipes both the inputs and the previous results.
        clear_btn.add([description_box, tags_box, output_box, metrics_box])

        generation_event = submit_btn.click(
            chat,
            inputs=(description_box, tags_box, quality_dropdown),
            outputs=(output_box, metrics_box),
            queue=True,
        )
        # Stop cancels the streaming generator started by Submit.
        stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])

    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()