public_chat / src /backend /model_loader.py
Fazzioni's picture
Update src/backend/model_loader.py
18e4f00 verified
"""Model loading utilities with Streamlit caching."""
import os
import streamlit as st
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
pipeline,
Pipeline,
)
import torch
# Obtém token do Hugging Face (disponível automaticamente no Spaces)
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
# Define o diretório de cache dentro do projeto
PROJECT_ROOT = Path(__file__).parent.parent.parent
MODELS_CACHE_DIR = PROJECT_ROOT / "models"
MODELS_CACHE_DIR.mkdir(exist_ok=True)
@st.cache_resource
def load_model(
model_name: str,
**args,
) -> Tuple[Pipeline, Dict[str, Any]]:
"""
Carrega um modelo do Hugging Face com cache do Streamlit.
Args:
model_name: Nome do modelo no Hugging Face (ex: 'microsoft/DialoGPT-medium')
Returns:
Tupla contendo (pipeline, model_info)
"""
try:
# Detecta dispositivo disponível
has_cuda = torch.cuda.is_available()
# Ajusta device_map: se não há GPU ou device_map é "auto" sem GPU, usa None
if has_cuda:
args['device_map'] = "cuda"
else:
args['device_map'] = 'cpu'
# Configurações de quantização
#model_kwargs = { "torch_dtype": torch_dtype,}
# Só adiciona device_map se não for None
# Carrega tokenizer e modelo usando cache do projeto
args['cache_dir'] = str(MODELS_CACHE_DIR)
# Prepara kwargs com token de autenticação se disponível
if HF_TOKEN:
args["token"] = HF_TOKEN
# Carrega tokenizer (pode ser de um repositório diferente)
tokenizer_model_name = args.pop('tokenizer',model_name)
tokenizer = AutoTokenizer.from_pretrained( tokenizer_model_name )
# Adiciona pad_token se não existir
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
args['trust_remote_code'] = True
# Carrega modelo usando a config carregada
model = AutoModelForCausalLM.from_pretrained( model_name, **args)
# Move modelo para CPU se não há GPU e device_map não foi usado
#if device_map is None and not has_cuda:
# model = model.to("cpu")
# Cria pipeline
pipeline_kwargs = {
"model": model,
"tokenizer": tokenizer,
'device_map': args["device_map"]
}
# Só adiciona device ao pipeline se não usar device_map no modelo
#if device_map is None:
# pipeline_kwargs["device"] = 0 if has_cuda else -1
#else:
# pipeline_kwargs["device_map"] = device_map
pipe = pipeline("text-generation", **pipeline_kwargs)
# Informações do modelo
model_info = {
"model_name": model_name,
"model_revision": args.get('revision','main'),
"tokenizer_name": tokenizer_model_name,
#"tokenizer_revision": tokenizer_revision,
"device": str(model.device),
"dtype": str(model.dtype),
#"quantized": load_in_8bit or load_in_4bit,
"cache_dir": args.get('cache_dir'),
}
return pipe, model_info
except Exception as e:
st.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
raise