# ELYZA-Diffusion Space — app.py
# (uploaded by yamadamya, revision 7e1b878, verified)
import os
import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
MODEL_ID = os.getenv("MODEL_ID", "elyza/ELYZA-Diffusion-Instruct-1.0-Dream-7B")
DEVICE = "cpu"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16, # 計算はfp16
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
print(f"Starting CPU quant Space: DEVICE={DEVICE}, MODEL_ID={MODEL_ID}")
model = AutoModel.from_pretrained(
MODEL_ID,
quantization_config=bnb_config,
device_map={"": DEVICE},
trust_remote_code=True,
).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
@torch.no_grad()
def generate(prompt, steps, max_new_tokens, temperature, top_p, alg_temp):
prompt = (prompt or "").strip()
if not prompt:
return "プロンプトを入力してください。"
steps = int(max(4, min(int(steps), 64)))
max_new_tokens = int(max(16, min(int(max_new_tokens), 128)))
messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
messages,
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
)
input_ids = inputs.input_ids.to(DEVICE)
attention_mask = inputs.attention_mask.to(DEVICE)
out = model.diffusion_generate(
input_ids,
attention_mask=attention_mask,
steps=steps,
max_new_tokens=max_new_tokens,
temperature=float(temperature),
top_p=float(top_p),
alg="entropy",
alg_temp=float(alg_temp),
)
return tokenizer.decode(out.sequences[0][input_ids.size(1):], skip_special_tokens=True)
with gr.Blocks() as demo:
gr.Markdown("## ELYZA Diffusion LLM (CPU 4bit quant)")
prompt = gr.Textbox(label="Prompt", lines=6, value="拡散言語モデルについて教えて")
with gr.Row():
steps = gr.Slider(4, 64, value=16, step=1, label="steps")
max_new_tokens = gr.Slider(16, 128, value=96, step=1, label="max_new_tokens")
with gr.Row():
temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="temperature")
top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
alg_temp = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="alg_temp")
run = gr.Button("Generate")
out = gr.Textbox(label="Output", lines=14)
run.click(generate, [prompt, steps, max_new_tokens, temperature, top_p, alg_temp], out)
demo.queue(max_size=8)
if __name__ == "__main__":
demo.launch()