# PE_TEST / app.py
# Uploaded by Johnny-Z — "Upload 2 files" (commit c95e677, verified)
import gradio as gr
import subprocess
import time
import requests
from openai import OpenAI
from huggingface_hub import login, snapshot_download
import os
import stat
import tarfile
import io
TITLE = "Zero-shot Anime Knowledge Optimizer"
DESCRIPTION = """
"""

# Authenticate with the Hugging Face Hub. The token must be supplied via the
# HF_TOKEN environment variable (e.g. a Space secret); fail fast otherwise.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    raise ValueError("environment variable HF_TOKEN not found.")

# Fetch the dataset repo that ships both the GGUF model and a prebuilt
# llama.cpp server archive for Ubuntu x64.
repo_id = "Johnny-Z/ZAKO-0.6B"
repo_dir = snapshot_download(repo_id, repo_type='dataset')
tar_path = os.path.join(repo_dir, "llama-b7972-bin-ubuntu-x64.tar.gz")
current_dir = os.path.dirname(os.path.abspath(__file__))

# Unpack the llama.cpp binaries next to this script. `filter="data"` is the
# safe extraction filter (rejects absolute paths / path traversal); older
# Python versions without that parameter raise TypeError, so fall back.
with tarfile.open(tar_path, mode="r:gz") as tar:
    try:
        tar.extractall(path=current_dir, filter="data")
    except TypeError:
        tar.extractall(path=current_dir)
def _find_llama_server(base_dir: str) -> str:
    """Locate the `llama-server` executable somewhere under *base_dir*.

    Returns the full path of the first match found while walking the tree;
    raises FileNotFoundError when no match exists.
    """
    matches = (
        os.path.join(dirpath, "llama-server")
        for dirpath, _subdirs, filenames in os.walk(base_dir)
        if "llama-server" in filenames
    )
    found = next(matches, None)
    if found is None:
        raise FileNotFoundError(f"未找到 llama-server,可执行文件不在 {base_dir} 及其子目录中")
    return found
def get_predicted_tokens_seconds() -> str:
    """Read the token-generation rate from the local llama.cpp /metrics page.

    Returns the value of the first well-formed
    `llamacpp:predicted_tokens_seconds` line as a string, or "N/A" when the
    server is unreachable, replies with an error, or the metric is absent.
    """
    metric_prefix = "llamacpp:predicted_tokens_seconds"
    try:
        reply = requests.get("http://localhost:8188/metrics", timeout=2)
        reply.raise_for_status()
    except requests.RequestException:
        return "N/A"
    for row in reply.text.splitlines():
        if not row.startswith(metric_prefix):
            continue
        fields = row.split()
        # Prometheus exposition format: "<name> <value>"; take the value.
        if len(fields) >= 2:
            return fields[-1]
    return "N/A"
# Resolve the extracted server binary and the quantized model file, then make
# sure the binary is executable — tar extraction does not always preserve the
# exec bit, and snapshot_download files are read-only blobs.
PATH_TO_SERVER_BINARY = _find_llama_server(current_dir)
PATH_TO_MODEL = os.path.join(repo_dir, "ZAKO-0.6B-Q4KM.gguf")
st = os.stat(PATH_TO_SERVER_BINARY)
os.chmod(PATH_TO_SERVER_BINARY, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
def wait_for_server(url: str, timeout_s: int = 180, interval_s: float = 0.5, process: subprocess.Popen | None = None) -> None:
start = time.time()
while time.time() - start < timeout_s:
if process and process.poll() is not None:
stderr = process.stderr.read().decode("utf-8", errors="ignore") if process.stderr else ""
raise RuntimeError(f"本地推理引擎启动失败,退出码={process.returncode}\n{stderr}")
try:
resp = requests.get(url, timeout=2)
if resp.status_code == 200:
return
except requests.RequestException:
pass
time.sleep(interval_s)
raise TimeoutError("本地推理引擎启动超时")
# Launch the llama.cpp HTTP server on port 8188 with the Prometheus metrics
# endpoint enabled. stderr is piped so startup failures can be reported by
# wait_for_server; stdout is discarded to keep the Space logs quiet.
server_process = subprocess.Popen(
    [
        PATH_TO_SERVER_BINARY,
        "-m", PATH_TO_MODEL,
        "--ctx-size", "1280",
        "--port", "8188",
        "--metrics"
    ],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.PIPE
)
print("正在启动本地推理引擎...")
wait_for_server("http://localhost:8188/health", process=server_process)

# llama.cpp exposes an OpenAI-compatible API; the key is a required-but-unused
# placeholder for the local endpoint.
client = OpenAI(
    base_url="http://localhost:8188/v1",
    api_key="sk-no-key-required"
)
def chat(question, tags, preference_level):
    """Stream a prompt-rewriting completion from the local ZAKO model.

    Builds a single user message from the quality level plus optional tags and
    question, then yields (accumulated_output, predicted_tokens_seconds)
    tuples so the UI can render partial text alongside the live generation
    speed. Runs as a generator for Gradio streaming.
    """
    prompt = f"""
# Role
Act as an image prompt writer. Your goal is to transform inputs into **objective, physical descriptions**. You must convert abstract concepts into concrete scenes, specifying composition, lighting, and textures. Any text to be rendered must be enclosed in double quotes `""` with its typography described. **Strictly avoid** subjective adjectives or quality tags (e.g. "8K", "Masterpiece", "Best Quality"). Output **only** the final visual description.
# User Input
Prompt Quality: {preference_level}
"""
    # Optional sections are appended only when the user actually filled them in.
    if tags.strip():
        prompt += f"\nTags: {tags}"
    if question.strip():
        prompt += f"\nQuestion: {question}"
    messages = [
        {"role": "user", "content": prompt}
    ]
    response = client.chat.completions.create(
        model="ZAKO",
        messages=messages,
        top_p=0.8,
        temperature=0.8,
        stream=True
    )
    output = ""
    for chunk in response:
        # OpenAI-compatible streams may emit chunks with an empty `choices`
        # list (e.g. trailing bookkeeping chunks) or a None delta content —
        # guard before indexing to avoid an IndexError mid-stream.
        if chunk.choices and chunk.choices[0].delta.content:
            output += chunk.choices[0].delta.content
            predicted_tokens_seconds = get_predicted_tokens_seconds()
            yield output, predicted_tokens_seconds
def main():
    """Build and launch the Gradio UI for the prompt optimizer."""
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
        with gr.Row():
            # Left panel: action buttons and user inputs.
            with gr.Column(variant="panel"):
                submit = gr.Button(value="Submit", variant="primary", size="lg")
                stop = gr.Button(value="Stop", variant="secondary", size="lg")
                with gr.Row():
                    text = gr.Textbox(
                        label="Simple Description",
                        value="",
                        lines=4,
                    )
                with gr.Row():
                    tags = gr.Textbox(
                        label="Tags",
                        value="",
                        lines=2,
                    )
                with gr.Row():
                    preference_level = gr.Dropdown(choices=["very high", "high", "normal"], value="very high", label="Prompt Quality")
                with gr.Row():
                    # Created with no targets; the components to clear are
                    # attached below via clear.add once they all exist.
                    clear = gr.ClearButton(
                        components=[],
                        variant="secondary",
                        size="lg",
                    )
                gr.Markdown(value=DESCRIPTION)
            # Right panel: streamed model output plus the live tokens/sec metric.
            with gr.Column(variant="panel"):
                generated_text = gr.Textbox(label="Output", lines=20)
                metrics_text = gr.Textbox(label="predicted_tokens_seconds", lines=1, interactive=False)
        clear.add([text, tags, generated_text, metrics_text])
        # `chat` is a generator, so this click event streams partial results
        # into both output boxes.
        stream_evt = submit.click(
            chat,
            inputs=(text, tags, preference_level),
            outputs=(generated_text, metrics_text),
            queue=True
        )
        # Stop cancels the in-flight streaming event rather than running a fn.
        stop.click(fn=None, inputs=None, outputs=None, cancels=[stream_evt])
        demo.queue(max_size=10)
        demo.launch()
# Standard script entry point: only launch the UI when run directly.
if __name__ == "__main__":
    main()