ChatBotIME / api.py
felipecaspol's picture
Corrección final: Se usa HF_HOME como caché segura en Hugging Face
cf00782
raw
history blame
2.1 kB
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import hf_hub_download
import os
# ✅ Inicializar FastAPI
app = FastAPI()
# ✅ Definir un directorio de caché seguro
os.environ["HF_HOME"] = "/tmp/huggingface"
# ✅ Nombre del modelo en Hugging Face Hub
HUGGING_FACE_REPO = "fcp2207/Phi-2" # Reemplaza con tu usuario y nombre correcto del modelo en Hugging Face
MODEL_FILENAME = "phi2_finetuned.pth" # Nombre del archivo en Hugging Face
# ✅ Descargar el modelo desde Hugging Face (usando la caché segura)
model_path = hf_hub_download(
repo_id=HUGGING_FACE_REPO,
filename=MODEL_FILENAME,
cache_dir=os.environ["HF_HOME"] # Directorio seguro en Hugging Face Spaces
)
# ✅ Cargar el tokenizer y el modelo base desde Hugging Face
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", cache_dir=os.environ["HF_HOME"])
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", cache_dir=os.environ["HF_HOME"])
# ✅ Cargar los pesos del modelo fine-tuned
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval() # Poner el modelo en modo inferencia
# ✅ Definir la estructura de la solicitud para la API
class InputText(BaseModel):
input_text: str
@app.get("/")
def home():
"""Endpoint de prueba para verificar que la API está activa"""
return {"message": "API de Chatbot con Phi-2 está en funcionamiento 🚀"}
@app.post("/predict/")
def predict(request: InputText):
"""Genera una respuesta basada en el input del usuario."""
inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(**inputs, max_length=150)
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"response": response_text}
# ✅ Ejecución en modo local (opcional, no necesario en Hugging Face)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)