"""Model loading utilities with Streamlit caching.""" import os import streamlit as st from pathlib import Path from typing import Optional, Dict, Any, Tuple from transformers import ( AutoModelForCausalLM, AutoTokenizer, AutoConfig, pipeline, Pipeline, ) import torch # Obtém token do Hugging Face (disponível automaticamente no Spaces) HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") # Define o diretório de cache dentro do projeto PROJECT_ROOT = Path(__file__).parent.parent.parent MODELS_CACHE_DIR = PROJECT_ROOT / "models" MODELS_CACHE_DIR.mkdir(exist_ok=True) @st.cache_resource def load_model( model_name: str, **args, ) -> Tuple[Pipeline, Dict[str, Any]]: """ Carrega um modelo do Hugging Face com cache do Streamlit. Args: model_name: Nome do modelo no Hugging Face (ex: 'microsoft/DialoGPT-medium') Returns: Tupla contendo (pipeline, model_info) """ try: # Detecta dispositivo disponível has_cuda = torch.cuda.is_available() # Ajusta device_map: se não há GPU ou device_map é "auto" sem GPU, usa None if has_cuda: args['device_map'] = "cuda" else: args['device_map'] = 'cpu' # Configurações de quantização #model_kwargs = { "torch_dtype": torch_dtype,} # Só adiciona device_map se não for None # Carrega tokenizer e modelo usando cache do projeto args['cache_dir'] = str(MODELS_CACHE_DIR) # Prepara kwargs com token de autenticação se disponível if HF_TOKEN: args["token"] = HF_TOKEN # Carrega tokenizer (pode ser de um repositório diferente) tokenizer_model_name = args.pop('tokenizer',model_name) tokenizer = AutoTokenizer.from_pretrained( tokenizer_model_name ) # Adiciona pad_token se não existir if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token args['trust_remote_code'] = True # Carrega modelo usando a config carregada model = AutoModelForCausalLM.from_pretrained( model_name, **args) # Move modelo para CPU se não há GPU e device_map não foi usado #if device_map is None and not has_cuda: # model = model.to("cpu") # Cria pipeline pipeline_kwargs = { "model": model, "tokenizer": tokenizer, 'device_map': args["device_map"] } # Só adiciona device ao pipeline se não usar device_map no modelo #if device_map is None: # pipeline_kwargs["device"] = 0 if has_cuda else -1 #else: # pipeline_kwargs["device_map"] = device_map pipe = pipeline("text-generation", **pipeline_kwargs) # Informações do modelo model_info = { "model_name": model_name, "model_revision": args.get('revision','main'), "tokenizer_name": tokenizer_model_name, #"tokenizer_revision": tokenizer_revision, "device": str(model.device), "dtype": str(model.dtype), #"quantized": load_in_8bit or load_in_4bit, "cache_dir": args.get('cache_dir'), } return pipe, model_info except Exception as e: st.error(f"Erro ao carregar modelo {model_name}: {str(e)}") raise