"""
Hugging Face LLM Chatbot with Gradio
Using transformers library to run models locally
"""
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Get HF token from environment (Spaces uses Secrets, local uses .env)
HF_TOKEN = os.getenv("HF_TOKEN", None)
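# Note: os.getenv does not read .env files by itself; when running locally you would
# typically load one first (assuming python-dotenv is installed):
#   from dotenv import load_dotenv
#   load_dotenv()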
# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Available models (optimized for local execution)
MODELS = {
    "microsoft/DialoGPT-small": {
        "name": "DialoGPT Small (English, fast)",
        "max_length": 80,
        "language": "en",
    },
    "microsoft/DialoGPT-medium": {
        "name": "DialoGPT Medium (English, higher quality)",
        "max_length": 100,
        "language": "en",
    },
    "gpt2": {
        "name": "GPT-2 (English, general purpose)",
        "max_length": 80,
        "language": "en",
    },
    "beomi/llama-2-ko-7b": {
        "name": "Llama-2-Ko 7B (Korean chat, ⚠️ needs 14GB+ RAM)",
        "max_length": 150,
        "language": "ko",
        "warning": "This model needs at least 14GB of memory. It may fail to run on the free HF Spaces tier due to insufficient memory.",
    },
    "kyujinpy/KoT-Llama2-7B-Chat": {
        "name": "KoT-Llama2-7B-Chat (Korean chat, ⚠️ needs 14GB+ RAM)",
        "max_length": 150,
        "language": "ko",
        "warning": "This model needs at least 14GB of memory. It may fail to run on the free HF Spaces tier due to insufficient memory.",
    },
    "beomi/KoAlpaca-Polyglot-5.8B": {
        "name": "KoAlpaca 5.8B (Korean chat, ⚠️ needs 12GB+ RAM)",
        "max_length": 150,
        "language": "ko",
        "warning": "This model needs at least 12GB of memory. It may fail to run on the free HF Spaces tier due to insufficient memory.",
    },
    "nlpai-lab/kullm-polyglot-5.8b-v2": {
        "name": "KULLM-Polyglot 5.8B (Korean chat, ⚠️ needs 12GB+ RAM)",
        "max_length": 150,
        "language": "ko",
        "warning": "This model needs at least 12GB of memory. It may fail to run on the free HF Spaces tier due to insufficient memory.",
    },
}
# Model cache
loaded_models = {}
loaded_tokenizers = {}

def load_model(model_name):
    """Load model and tokenizer"""
    if model_name not in loaded_models:
        try:
            print(f"Loading model: {model_name}")
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=HF_TOKEN,
                padding_side='left',
                trust_remote_code=True
            )
            # Add pad token if missing
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
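            # GPT-2/DialoGPT tokenizers define no pad token, so EOS is reused for
            # padding; padding_side='left' keeps the prompt right-aligned, which is
            # what decoder-only models expect during generation.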
            # Load model with safetensors support
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    token=HF_TOKEN,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                    use_safetensors=True
                )
            except Exception as e:
                # Fall back to default loading if safetensors fails
                print(f"⚠️ Safetensors loading failed, trying default method: {e}")
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    token=HF_TOKEN,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True
                )
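            # Note: use_safetensors=True forces safetensors checkpoints, which avoids
            # pickle-based .bin loading; repositories that only publish .bin weights
            # raise an error there, which is what the fallback above handles.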
            model.to(device)
            model.eval()
            loaded_models[model_name] = model
            loaded_tokenizers[model_name] = tokenizer
            print(f"✅ Model {model_name} loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load model {model_name}: {e}")
            return None, None
    return loaded_models.get(model_name), loaded_tokenizers.get(model_name)

def chat_response(message, history, model_name):
    """
    Generate chatbot response

    Args:
        message: User input
        history: Chat history in Gradio "messages" format
        model_name: Selected model

    Returns:
        Response text
    """
    try:
        # Load model and tokenizer
        model, tokenizer = load_model(model_name)
        if model is None or tokenizer is None:
            return f"❌ Could not load model '{model_name}'. Please select a different model."
        model_config = MODELS[model_name]
        # Build conversation context
        conversation = ""
        for msg in history:
            if msg["role"] == "user":
                conversation += f"{msg['content']}\n"
            elif msg["role"] == "assistant":
                conversation += f"{msg['content']}\n"
        # Add current message
        conversation += f"{message}\n"
        # Tokenize
        inputs = tokenizer.encode(conversation, return_tensors="pt").to(device)
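        # Caution: the prompt is not truncated, so a long chat history can exceed the
        # model's context window (1024 tokens for GPT-2/DialoGPT). One option is to
        # keep only the most recent tokens, e.g.:
        #   inputs = inputs[:, -1024:]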
        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=model_config["max_length"],
                temperature=0.9,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
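        # Note: generate() is called without an attention_mask, so transformers may
        # log a warning; building the inputs with tokenizer(conversation,
        # return_tensors="pt") and passing its attention_mask would avoid that.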
        # Decode response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Remove the input prompt from response
        response = response[len(conversation):].strip()
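        # Caution: slicing by len(conversation) assumes the decoded text reproduces the
        # prompt character-for-character, which tokenization does not always guarantee.
        # A more robust alternative is to decode only the newly generated tokens:
        #   response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)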
        # If empty, return a default message
        if not response:
            response = "I understand. Could you tell me more?"
        return response
    except Exception as e:
        import traceback
        error_msg = str(e)
        error_type = type(e).__name__
        print("=" * 50)
        print(f"Error Type: {error_type}")
        print(f"Error Message: {error_msg}")
        print(f"Traceback:\n{traceback.format_exc()}")
        print("=" * 50)
        if "out of memory" in error_msg.lower() or "oom" in error_msg.lower():
            return "❌ Out of memory. Select a smaller model or restart the app."
        elif "cuda" in error_msg.lower() and device == "cpu":
            return "⚠️ Running on CPU without a GPU. Responses may be slow."
        else:
            return f"❌ Error: {error_type}\n{error_msg[:200]}\n\nCheck the terminal for the full log."
# Global state
current_model = "microsoft/DialoGPT-small"
# Preload default model
print("Preloading default model...")
load_model(current_model)
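# Note: preloading happens at import time, so app startup (including the first Spaces
# build) blocks until the default model has been downloaded and loaded.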
# Create Gradio interface
with gr.Blocks(
    title="🤖 Hugging Face Chatbot",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        """
        # 🤖 Hugging Face LLM Chatbot

        **Runs models locally** - no API rate limits!

        **How to use:**
        1. Pick a model (the first load takes a while)
        2. Type a message and start chatting
        3. Everything runs on CPU, so responses can be a bit slow

        **Recommended models by language:**
        - 🇬🇧 English: DialoGPT, GPT-2
        - 🇰🇷 Korean: Llama-2-Ko, KoAlpaca (the 5.8B+ models are large and slow)

        **Advantages:** no API limits, completely free, can run offline
        """
    )
    # Model selector
    model_dropdown = gr.Dropdown(
        choices=[(config["name"], model_id) for model_id, config in MODELS.items()],
        value="microsoft/DialoGPT-small",
        label="🎯 Select a model",
        info="Changing the model downloads the new model (first time only)",
    )
    # Warning message for model requirements
    model_warning = gr.Markdown("", visible=False)
    # Chat interface
    chatbot = gr.ChatInterface(
        fn=chat_response,
        type="messages",
        additional_inputs=[model_dropdown],
        chatbot=gr.Chatbot(
            height=500,
            placeholder="Type a message...",
            type="messages",
        ),
        textbox=gr.Textbox(
            placeholder="Type a message (English recommended)...",
            container=False,
            scale=7,
        ),
        examples=[
            ["Hello! How are you?", "microsoft/DialoGPT-small"],
            ["Tell me a joke", "microsoft/DialoGPT-medium"],
            ["안녕하세요! 오늘 날씨가 어때요?", "beomi/llama-2-ko-7b"],
            ["인공지능에 대해 간단히 설명해주세요.", "kyujinpy/KoT-Llama2-7B-Chat"],
        ],
    )
    # Show warning and clear chat when model changes
    def on_model_change(new_model):
        global current_model
        current_model = new_model
        # Check if model has warning
        warning_text = ""
        warning_visible = False
        if "warning" in MODELS[new_model]:
            warning_text = f"⚠️ **Warning**: {MODELS[new_model]['warning']}"
            warning_visible = True
        # Preload new model
        load_model(new_model)
        # Return: cleared chat history and a single warning update (text + visibility)
        return [], gr.update(value=warning_text, visible=warning_visible)
    model_dropdown.change(
        fn=on_model_change,
        inputs=[model_dropdown],
        outputs=[chatbot.chatbot_state, model_warning],
    )
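    # Note: on_model_change calls load_model synchronously, so switching to a large
    # model blocks the UI until the download and load have finished.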
    gr.Markdown(
        """
        ---
        **⚠️ Notes:**
        - Models run locally (downloaded on first use)
        - Everything runs on CPU, so it is slower than GPU
        - Each model is optimized for a specific language

        **💾 Disk usage:**
        - DialoGPT-small: ~350MB
        - DialoGPT-medium: ~800MB
        - GPT-2: ~500MB
        - KoAlpaca-5.8B: ~12GB (large model, needs 12GB+ RAM)

        **💡 Tips:**
        - For English chat, DialoGPT is recommended
        - Use the large Korean models (5.8B/7B) only when resources allow
        - Shorter messages tend to give better results
        - Once a model is loaded, it is not downloaded again
        """
    )
if __name__ == "__main__":
    demo.launch()