ChatBotIME / api.py
felipecaspol's picture
Optimización de memoria en Hugging Face Spaces
4e469f4
raw
history blame
2.26 kB
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import hf_hub_download
import os
# ✅ Inicializar FastAPI
app = FastAPI()
# ✅ Definir un directorio de caché seguro
os.environ["HF_HOME"] = "/tmp/huggingface"
# ✅ Nombre del modelo en Hugging Face Hub
HUGGING_FACE_REPO = "fcp2207/Phi-2" # Reemplaza con tu usuario y nombre correcto del modelo en Hugging Face
MODEL_FILENAME = "phi2_finetuned.pth" # Nombre del archivo en Hugging Face
# ✅ Descargar el modelo desde Hugging Face con caché segura
model_path = hf_hub_download(
repo_id=HUGGING_FACE_REPO,
filename=MODEL_FILENAME,
cache_dir=os.environ["HF_HOME"] # Directorio seguro en Hugging Face Spaces
)
# ✅ Cargar el tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", cache_dir=os.environ["HF_HOME"])
# ✅ Cargar el modelo en modo optimizado para memoria
model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-2",
cache_dir=os.environ["HF_HOME"],
torch_dtype=torch.float16, # Reduce el tamaño del modelo
device_map="auto" # Optimiza la carga en CPU/GPU automáticamente
)
# ✅ Cargar los pesos del modelo entrenado
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval() # Poner el modelo en modo inferencia
# ✅ Definir la estructura de la solicitud para la API
class InputText(BaseModel):
input_text: str
@app.get("/")
def home():
"""Endpoint de prueba para verificar que la API está activa"""
return {"message": "API de Chatbot con Phi-2 está en funcionamiento 🚀"}
@app.post("/predict/")
def predict(request: InputText):
"""Genera una respuesta basada en el input del usuario."""
inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(**inputs, max_length=150)
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"response": response_text}
# ✅ Ejecución en modo local (opcional, no necesario en Hugging Face)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)