ChatBotIME / api.py
felipecaspol's picture
Corregido error de permisos en Hugging Face (caché personalizada)
fa720ef
raw
history blame
2.04 kB
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import hf_hub_download
import os
# ✅ Inicializar FastAPI
app = FastAPI()
# ✅ Crear directorio de caché para evitar problemas de permisos
CACHE_DIR = "./cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# ✅ Nombre del modelo en Hugging Face Hub
HUGGING_FACE_REPO = "fcp2207/Phi-2" # Asegúrate de que sea el nombre correcto en Hugging Face
MODEL_FILENAME = "phi2_finetuned.pth" # Nombre del archivo en Hugging Face
# ✅ Descargar el modelo desde Hugging Face (especificando caché)
model_path = hf_hub_download(
repo_id=HUGGING_FACE_REPO,
filename=MODEL_FILENAME,
cache_dir=CACHE_DIR # Ruta de caché permitida
)
# ✅ Cargar el tokenizer y el modelo base
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", cache_dir=CACHE_DIR)
# ✅ Cargar los pesos del modelo
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval() # Poner el modelo en modo inferencia
# ✅ Definir la estructura de la solicitud para la API
class InputText(BaseModel):
input_text: str
@app.get("/")
def home():
"""Endpoint de prueba para verificar que la API está activa"""
return {"message": "API de Chatbot con Phi-2 está en funcionamiento 🚀"}
@app.post("/predict/")
def predict(request: InputText):
"""Genera una respuesta basada en el input del usuario."""
inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(**inputs, max_length=150)
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"response": response_text}
# ✅ Ejecución en modo local (opcional, no necesario en Hugging Face)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)