from fastapi import FastAPI from pydantic import BaseModel from transformers import AutoModelForCausalLM, AutoTokenizer import torch from huggingface_hub import hf_hub_download import os # ✅ Inicializar FastAPI app = FastAPI() # ✅ Definir un directorio de caché seguro os.environ["HF_HOME"] = "/tmp/huggingface" # ✅ Nombre del modelo en Hugging Face Hub HUGGING_FACE_REPO = "fcp2207/Phi-2" # Reemplaza con tu usuario y nombre correcto del modelo en Hugging Face MODEL_FILENAME = "phi2_finetuned.pth" # Nombre del archivo en Hugging Face # ✅ Descargar el modelo desde Hugging Face con caché segura model_path = hf_hub_download( repo_id=HUGGING_FACE_REPO, filename=MODEL_FILENAME, cache_dir=os.environ["HF_HOME"] # Directorio seguro en Hugging Face Spaces ) # ✅ Cargar el tokenizer tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", cache_dir=os.environ["HF_HOME"]) # ✅ Cargar el modelo en modo optimizado para memoria model = AutoModelForCausalLM.from_pretrained( "microsoft/phi-2", cache_dir=os.environ["HF_HOME"], torch_dtype=torch.float16, # Reduce el tamaño del modelo device_map="auto" # Optimiza la carga en CPU/GPU automáticamente ) # ✅ Cargar los pesos del modelo entrenado model.load_state_dict(torch.load(model_path, map_location="cpu")) model.eval() # Poner el modelo en modo inferencia # ✅ Definir la estructura de la solicitud para la API class InputText(BaseModel): input_text: str @app.get("/") def home(): """Endpoint de prueba para verificar que la API está activa""" return {"message": "API de Chatbot con Phi-2 está en funcionamiento 🚀"} @app.post("/predict/") def predict(request: InputText): """Genera una respuesta basada en el input del usuario.""" inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True, max_length=512) with torch.no_grad(): outputs = model.generate(**inputs, max_length=150) response_text = tokenizer.decode(outputs[0], skip_special_tokens=True) return {"response": response_text} # ✅ Ejecución en modo local (opcional, no necesario en Hugging Face) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)