| import gradio as gr | |
| import subprocess | |
| import time | |
| import requests | |
| from openai import OpenAI | |
| from huggingface_hub import login, snapshot_download | |
| import os | |
| import stat | |
| import tarfile | |
| import io | |
TITLE = "Zero-shot Anime Knowledge Optimizer"
DESCRIPTION = """
"""

# The Hub token is mandatory: the model repository below is gated behind it.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("environment variable HF_TOKEN not found.")
login(token=hf_token)

# Fetch the dataset repo that ships both the GGUF model and a prebuilt
# llama.cpp binary bundle, then unpack the bundle next to this script.
repo_id = "Johnny-Z/ZAKO-0.6B"
repo_dir = snapshot_download(repo_id, repo_type='dataset')
tar_path = os.path.join(repo_dir, "llama-b7972-bin-ubuntu-x64.tar.gz")
current_dir = os.path.dirname(os.path.abspath(__file__))
with tarfile.open(tar_path, mode="r:gz") as tar:
    try:
        # Python 3.12+: extract with the safe "data" filter.
        tar.extractall(path=current_dir, filter="data")
    except TypeError:
        # Older interpreters do not accept the ``filter`` keyword.
        tar.extractall(path=current_dir)
| def _find_llama_server(base_dir: str) -> str: | |
| for root, _, files in os.walk(base_dir): | |
| if "llama-server" in files: | |
| return os.path.join(root, "llama-server") | |
| raise FileNotFoundError(f"未找到 llama-server,可执行文件不在 {base_dir} 及其子目录中") | |
def get_predicted_tokens_seconds() -> str:
    """Read the local llama.cpp Prometheus endpoint and return the value of
    the ``llamacpp:predicted_tokens_seconds`` metric as a string.

    Returns "N/A" when the server is unreachable, answers with an error
    status, or does not expose the metric.
    """
    try:
        response = requests.get("http://localhost:8188/metrics", timeout=2)
        response.raise_for_status()
    except requests.RequestException:
        return "N/A"
    for metric_line in response.text.splitlines():
        if not metric_line.startswith("llamacpp:predicted_tokens_seconds"):
            continue
        fields = metric_line.split()
        if len(fields) >= 2:
            # Prometheus text format: the sample value is the last field.
            return fields[-1]
    return "N/A"
# Locate the extracted server binary and the downloaded GGUF model, then
# make sure the binary carries execute permission for user/group/other
# (tar extraction does not always preserve the executable bit).
PATH_TO_SERVER_BINARY = _find_llama_server(current_dir)
PATH_TO_MODEL = os.path.join(repo_dir, "ZAKO-0.6B-Q4KM.gguf")
st = os.stat(PATH_TO_SERVER_BINARY)
os.chmod(PATH_TO_SERVER_BINARY, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
| def wait_for_server(url: str, timeout_s: int = 180, interval_s: float = 0.5, process: subprocess.Popen | None = None) -> None: | |
| start = time.time() | |
| while time.time() - start < timeout_s: | |
| if process and process.poll() is not None: | |
| stderr = process.stderr.read().decode("utf-8", errors="ignore") if process.stderr else "" | |
| raise RuntimeError(f"本地推理引擎启动失败,退出码={process.returncode}\n{stderr}") | |
| try: | |
| resp = requests.get(url, timeout=2) | |
| if resp.status_code == 200: | |
| return | |
| except requests.RequestException: | |
| pass | |
| time.sleep(interval_s) | |
| raise TimeoutError("本地推理引擎启动超时") | |
# Start the llama.cpp HTTP server on port 8188, block until /health answers,
# then point an OpenAI-compatible client at its local endpoint.
server_process = subprocess.Popen(
    [
        PATH_TO_SERVER_BINARY,
        "-m", PATH_TO_MODEL,
        "--ctx-size", "1280",
        "--port", "8188",
        "--metrics",
    ],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.PIPE,
)
print("正在启动本地推理引擎...")
wait_for_server("http://localhost:8188/health", process=server_process)

client = OpenAI(
    base_url="http://localhost:8188/v1",
    api_key="sk-no-key-required",
)
def chat(question, tags, preference_level):
    """Stream a prompt-expansion completion from the local ZAKO model.

    Builds a single user message from the optional free-text *question*,
    optional *tags*, and the requested *preference_level*, then yields
    ``(accumulated_output, predicted_tokens_seconds)`` tuples so the UI can
    refresh both the text box and the throughput metric as chunks arrive.
    """
    prompt = f"""
# Role
Act as an image prompt writer. Your goal is to transform inputs into **objective, physical descriptions**. You must convert abstract concepts into concrete scenes, specifying composition, lighting, and textures. Any text to be rendered must be enclosed in double quotes `""` with its typography described. **Strictly avoid** subjective adjectives or quality tags (e.g. "8K", "Masterpiece", "Best Quality"). Output **only** the final visual description.
# User Input
Prompt Quality: {preference_level}
"""
    if tags.strip():
        prompt += f"\nTags: {tags}"
    if question.strip():
        prompt += f"\nQuestion: {question}"
    stream = client.chat.completions.create(
        model="ZAKO",
        messages=[{"role": "user", "content": prompt}],
        top_p=0.8,
        temperature=0.8,
        stream=True,
    )
    accumulated = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            accumulated += delta
            # Refresh the tokens/sec gauge alongside every content update.
            yield accumulated, get_predicted_tokens_seconds()
def main():
    """Assemble the Gradio interface and launch it."""
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
        with gr.Row():
            # Left panel: user inputs and controls.
            with gr.Column(variant="panel"):
                submit = gr.Button(value="Submit", variant="primary", size="lg")
                stop = gr.Button(value="Stop", variant="secondary", size="lg")
                with gr.Row():
                    text = gr.Textbox(label="Simple Description", value="", lines=4)
                with gr.Row():
                    tags = gr.Textbox(label="Tags", value="", lines=2)
                with gr.Row():
                    preference_level = gr.Dropdown(
                        choices=["very high", "high", "normal"],
                        value="very high",
                        label="Prompt Quality",
                    )
                with gr.Row():
                    clear = gr.ClearButton(components=[], variant="secondary", size="lg")
                gr.Markdown(value=DESCRIPTION)
            # Right panel: streamed model output and the throughput metric.
            with gr.Column(variant="panel"):
                generated_text = gr.Textbox(label="Output", lines=20)
                metrics_text = gr.Textbox(
                    label="predicted_tokens_seconds", lines=1, interactive=False
                )
        # Wiring: Clear resets all four text widgets, Submit starts the
        # streaming generation, Stop cancels the in-flight stream event.
        clear.add([text, tags, generated_text, metrics_text])
        stream_evt = submit.click(
            chat,
            inputs=(text, tags, preference_level),
            outputs=(generated_text, metrics_text),
            queue=True,
        )
        stop.click(fn=None, inputs=None, outputs=None, cancels=[stream_evt])
    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()