import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig import os # import subprocess # 현재 코드에서 사용되지 않으므로 제거 가능 import torch from huggingface_hub import login # 환경 변수에서 토큰 가져오기 token = os.environ.get("HF_TOKEN") # 일반적으로 "HF_TOKEN"으로 설정됩니다. if token: login(token) else: print("HF_TOKEN 환경 변수가 설정되지 않았습니다. 모델 다운로드에 문제가 있을 수 있습니다.") # ---------- STEP 1: Fine-tuned 모델 정보 ---------- repo_id = "DMID23/MachineToolAgent" # 모델 저장소 ID # ---------- STEP 2: 양자화 설정 및 모델 로드 ---------- # 8bit 양자화 설정 (CPU 환경에서도 사용 가능) # load_in_8bit=True 옵션만으로도 BitsAndBytesConfig 객체를 자동으로 생성하여 적용합니다. # CPU에서는 float32 -> int8 양자화가 주로 일어납니다. quantization_config = BitsAndBytesConfig(load_in_8bit=True) model = AutoModelForCausalLM.from_pretrained( repo_id, quantization_config=quantization_config, # 양자화 설정 적용 torch_dtype=torch.float32, # 8비트 로드 시에도 내부적으로 float32로 처리되거나 혼합 정밀도로 작동할 수 있습니다. # 하지만 실제 메모리는 8비트만큼만 사용됩니다. device_map="auto" # 모델의 각 레이어를 자동으로 최적의 장치(CPU/GPU)에 분배 # CPU만 있다면 CPU로 로드됩니다. ) print("Model loaded successfully.") # 만약 DMID23/MachineToolAgent 저장소에 토크나이저가 있다면 repo_id로 바꾸세요. tokenizer = AutoTokenizer.from_pretrained(repo_id) # pipe 설정 시, device=-1 (CPU) 명시 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) # # ---------- STEP 3: Gradio 함수 정의 ---------- # (이 부분은 변경 없음) def generate_response(prompt, max_length=256, temperature=0.7): # max_length를 제한하여 속도를 빠르게 함 outputs = pipe( prompt, max_length=max_length, temperature=temperature, do_sample=True, top_p=0.9, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id, ) return outputs[0]["generated_text"] # ---------- STEP 4: Gradio UI ---------- # (이 부분은 변경 없음) with gr.Blocks() as demo: gr.Markdown("# 🚀 Fine-tuned Mistral-7B (CPU Optimized)") with gr.Row(): prompt_input = gr.Textbox(label="Input Prompt", placeholder="Type your prompt here...", lines=4) with gr.Row(): max_len_slider = gr.Slider(64, 512, value=256, step=16, label="Max Length (lower = faster)") temp_slider = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature") generate_button = gr.Button("Generate") output_box = gr.Textbox(label="Generated Output", lines=10) generate_button.click( fn=generate_response, inputs=[prompt_input, max_len_slider, temp_slider], outputs=output_box, ) # ---------- STEP 5: Launch ---------- demo.launch()