ChatBotIME / api.py
felipecaspol's picture
Corrigiendo API para modelo fusionado
be48726
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
import os
# ✅ Inicializar FastAPI
app = FastAPI()
# ✅ Definir un directorio de caché seguro
os.environ["HF_HOME"] = "/tmp/huggingface"
# ✅ Repositorio del modelo fusionado (actualizado)
HUGGING_FACE_REPO = "fcp2207/Fusion_modelo_Phi2" # ✅ Debe coincidir con el Space donde guardaste `phi2_full_model`
# ✅ Descargar el modelo fusionado desde Hugging Face
print("🔄 Descargando modelo fusionado...")
model_path = snapshot_download(repo_id=HUGGING_FACE_REPO, cache_dir=os.environ["HF_HOME"])
# ✅ Cargar el tokenizer desde el modelo fusionado
print("🔄 Cargando tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# ✅ Cargar el modelo en modo optimizado para memoria
print("🔄 Cargando modelo...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16, # Reduce el tamaño del modelo
device_map="auto" # Optimiza la carga en CPU/GPU automáticamente
)
model.eval() # Poner el modelo en modo inferencia
# ✅ Definir la estructura de la solicitud para la API
class InputText(BaseModel):
input_text: str
@app.get("/")
def home():
"""Endpoint de prueba para verificar que la API está activa"""
return {"message": "API de Chatbot con Phi-2 fusionado está en funcionamiento 🚀"}
@app.post("/predict/")
def predict(request: InputText):
"""Genera una respuesta basada en el input del usuario."""
inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(**inputs, max_length=150)
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"response": response_text}
# ✅ Ejecución en modo local (opcional, no necesario en Hugging Face)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)