# Spindle-LLM / app.py
# Author: DMID23 — "Update app.py" (commit 14346b1, verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
import os
# import subprocess # ν˜„μž¬ μ½”λ“œμ—μ„œ μ‚¬μš©λ˜μ§€ μ•ŠμœΌλ―€λ‘œ 제거 κ°€λŠ₯
import torch
from huggingface_hub import login
# ν™˜κ²½ λ³€μˆ˜μ—μ„œ 토큰 κ°€μ Έμ˜€κΈ°
token = os.environ.get("HF_TOKEN") # 일반적으둜 "HF_TOKEN"으둜 μ„€μ •λ©λ‹ˆλ‹€.
if token:
login(token)
else:
print("HF_TOKEN ν™˜κ²½ λ³€μˆ˜κ°€ μ„€μ •λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€. λͺ¨λΈ λ‹€μš΄λ‘œλ“œμ— λ¬Έμ œκ°€ μžˆμ„ 수 μžˆμŠ΅λ‹ˆλ‹€.")
# ---------- STEP 1: Fine-tuned model info ----------
repo_id = "DMID23/MachineToolAgent"  # model repository ID

# ---------- STEP 2: Quantization setup and model load ----------
# NOTE(review): bitsandbytes 8-bit quantization requires a CUDA GPU.
# The original code claimed it also works on CPU, which is not the case —
# on a CPU-only Space the load crashes. Quantize only when a GPU exists,
# otherwise fall back to a plain float32 load. Also dropped the
# contradictory torch_dtype=float32 alongside the 8-bit config.
if torch.cuda.is_available():
    # 8-bit weights use roughly a quarter of the float32 memory footprint.
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        quantization_config=quantization_config,
        device_map="auto",  # distribute layers over the available device(s)
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float32,
        device_map="auto",  # CPU-only: everything lands on the CPU
    )
print("Model loaded successfully.")

# Tokenizer shipped in the same repository as the fine-tuned model.
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# The model already carries its device placement via device_map, so the
# pipeline must NOT be given an explicit `device` argument.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# ---------- STEP 3: Gradio handler ----------
def generate_response(prompt, max_length=256, temperature=0.7):
    """Generate a completion for *prompt* using the loaded pipeline.

    Args:
        prompt: User prompt text from the UI.
        max_length: Upper bound on *newly generated* tokens (kept small
            for CPU speed). Parameter name retained for UI compatibility.
        temperature: Sampling temperature; higher values = more random.

    Returns:
        The generated text (by default the pipeline output includes the
        prompt itself).
    """
    # Guard: nothing to generate from an empty/whitespace-only prompt.
    if not prompt or not prompt.strip():
        return ""
    outputs = pipe(
        prompt,
        # Fix: the original passed max_length, which counts prompt tokens
        # too — a long prompt left no room for generation (or errored).
        # max_new_tokens bounds only the generated continuation.
        # int() because Gradio sliders may deliver floats.
        max_new_tokens=int(max_length),
        temperature=temperature,
        do_sample=True,
        top_p=0.9,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # silence missing-pad warning
    )
    return outputs[0]["generated_text"]
# ---------- STEP 4: Gradio UI ----------
# Declarative layout: prompt box, two generation-parameter sliders, a
# trigger button, and an output box wired to generate_response.
with gr.Blocks() as demo:
    gr.Markdown("# πŸš€ Fine-tuned Mistral-7B (CPU Optimized)")
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Input Prompt",
            placeholder="Type your prompt here...",
            lines=4,
        )
    with gr.Row():
        length_slider = gr.Slider(
            64, 512, value=256, step=16,
            label="Max Length (lower = faster)",
        )
        temperature_slider = gr.Slider(
            0.1, 1.5, value=0.7, step=0.1,
            label="Temperature",
        )
    run_button = gr.Button("Generate")
    result_box = gr.Textbox(label="Generated Output", lines=10)
    run_button.click(
        fn=generate_response,
        inputs=[prompt_box, length_slider, temperature_slider],
        outputs=result_box,
    )

# ---------- STEP 5: Launch ----------
demo.launch()