File size: 3,542 Bytes
91c131d
 
 
 
 
 
 
 
 
a162a4c
91c131d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efbd8ca
91c131d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d8b07
 
 
 
 
91c131d
81d8b07
91c131d
 
 
 
81d8b07
91c131d
 
 
81d8b07
 
91c131d
cb9558e
268ccb1
81d8b07
91c131d
 
 
 
81d8b07
 
 
a162a4c
efbd8ca
91c131d
 
7111032
 
91c131d
 
 
 
 
7111032
91c131d
 
 
7111032
 
 
 
91c131d
 
 
 
 
 
81d8b07
 
 
 
 
 
18e4f00
91c131d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Model loading utilities with Streamlit caching."""

import os
import streamlit as st
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    pipeline,
    Pipeline,
)
import torch

# Obtém token do Hugging Face (disponível automaticamente no Spaces)
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")

# Define o diretório de cache dentro do projeto
PROJECT_ROOT = Path(__file__).parent.parent.parent
MODELS_CACHE_DIR = PROJECT_ROOT / "models"
MODELS_CACHE_DIR.mkdir(exist_ok=True)


@st.cache_resource
def load_model(
    model_name: str,
    **args,
) -> Tuple[Pipeline, Dict[str, Any]]:
    """
    Carrega um modelo do Hugging Face com cache do Streamlit.
    
    Args:
        model_name: Nome do modelo no Hugging Face (ex: 'microsoft/DialoGPT-medium')
        
    Returns:
        Tupla contendo (pipeline, model_info)
    """
    try:
        # Detecta dispositivo disponível
        has_cuda = torch.cuda.is_available()
        
        # Ajusta device_map: se não há GPU ou device_map é "auto" sem GPU, usa None
        if has_cuda:
            args['device_map'] = "cuda"
        else:
            args['device_map'] = 'cpu'
            
        # Configurações de quantização
        #model_kwargs = { "torch_dtype": torch_dtype,}
        
        # Só adiciona device_map se não for None
        
        # Carrega tokenizer e modelo usando cache do projeto
        args['cache_dir'] = str(MODELS_CACHE_DIR)
        
        # Prepara kwargs com token de autenticação se disponível
        if HF_TOKEN:
            args["token"] = HF_TOKEN
            
        
        # Carrega tokenizer (pode ser de um repositório diferente)
        tokenizer_model_name = args.pop('tokenizer',model_name)
        tokenizer = AutoTokenizer.from_pretrained( tokenizer_model_name )
        
        # Adiciona pad_token se não existir
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        args['trust_remote_code'] = True

        # Carrega modelo usando a config carregada
        model = AutoModelForCausalLM.from_pretrained( model_name, **args)
        
        # Move modelo para CPU se não há GPU e device_map não foi usado
        #if device_map is None and not has_cuda:
        #    model = model.to("cpu")
        
        # Cria pipeline
        pipeline_kwargs = {
            "model": model,
            "tokenizer": tokenizer,
            'device_map': args["device_map"]
        }
        
        # Só adiciona device ao pipeline se não usar device_map no modelo
        #if device_map is None:
        #    pipeline_kwargs["device"] = 0 if has_cuda else -1
        #else:
        #    pipeline_kwargs["device_map"] = device_map
        
        pipe = pipeline("text-generation", **pipeline_kwargs)
        
        # Informações do modelo
        model_info = {
            "model_name": model_name,
            "model_revision": args.get('revision','main'),
            "tokenizer_name": tokenizer_model_name,
            #"tokenizer_revision": tokenizer_revision,
            "device": str(model.device),
            "dtype": str(model.dtype),
            #"quantized": load_in_8bit or load_in_4bit,
            "cache_dir": args.get('cache_dir'),
        }
        
        return pipe, model_info
        
    except Exception as e:
        st.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
        raise