api_v1 / app.py
Kfjjdjdjdhdhd's picture
Update app.py
c4040a5 verified
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import asyncio
import json
import time
from mlx_lm import load, generate
app = FastAPI(
title="API de generación de texto (Concurrente, Async, Stream y estilo Ollama)",
description="API que utiliza el modelo gemma-3-text-4b-it-4bit, forzando el uso de CPU para evitar problemas con el backend Metal, y transmite la respuesta en streaming con formato estilo Ollama. Se controlan múltiples solicitudes simultáneas mediante un semáforo.",
version="1.0.0"
)
# Cargar el modelo y tokenizer utilizando la CPU
MODEL_NAME = "mlx-community/gemma-3-text-4b-it-4bit"
try:
model, tokenizer = load(MODEL_NAME) # Forzar uso de CPU
except Exception as e:
raise RuntimeError(f"No se pudo cargar el modelo {MODEL_NAME}: {e}")
# Límite máximo de tokens por chunk (ajústalo según tus necesidades)
MAX_TOKENS_PER_CHUNK = 200
# Definir el número máximo de solicitudes concurrentes que utilizarán el modelo
CONCURRENT_REQUESTS = 5
semaphore = asyncio.Semaphore(CONCURRENT_REQUESTS)
class PromptRequest(BaseModel):
prompt: str
def split_response_into_chunks(response: str, tokenizer, max_tokens: int):
"""
Divide la respuesta en fragmentos (chunks) según la cantidad de tokens.
Si la cantidad de tokens es menor o igual a max_tokens, se devuelve una lista con el texto completo.
"""
tokens = tokenizer.encode(response, add_special_tokens=False)
if len(tokens) <= max_tokens:
return [response]
# Dividir los tokens en grupos de max_tokens
token_chunks = [tokens[i:i+max_tokens] for i in range(0, len(tokens), max_tokens)]
# Decodificar cada grupo de tokens en texto
text_chunks = [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in token_chunks]
return text_chunks
@app.post("/generate-stream-ollama")
async def generate_text_stream_ollama(request: PromptRequest):
prompt = request.prompt
# Si el tokenizer posee plantilla de chat, la aplicamos
if tokenizer.chat_template is not None:
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Usar el semáforo para limitar el número de solicitudes concurrentes
async with semaphore:
loop = asyncio.get_running_loop()
try:
response = await loop.run_in_executor(None, generate, model, tokenizer, prompt, True)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generando la respuesta: {e}")
# Dividir la respuesta en chunks basados en tokens
chunks = split_response_into_chunks(response, tokenizer, MAX_TOKENS_PER_CHUNK)
async def stream_response():
# Enviar cada chunk en formato JSON siguiendo un estilo similar al de Ollama
for chunk in chunks:
event = {
"id": f"stream-{int(time.time()*1000)}", # ID único para el fragmento
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": MODEL_NAME,
"choices": [
{
"delta": {"content": chunk},
"index": 0,
"finish_reason": None
}
]
}
# Cada mensaje se envía con el prefijo "data:" y dos saltos de línea, como en SSE
yield f"data: {json.dumps(event)}\n\n"
await asyncio.sleep(0.05)
# Indicar el fin del stream
yield "data: [DONE]\n\n"
# Se utiliza "text/event-stream" para que el cliente interprete la respuesta como SSE
return StreamingResponse(stream_response(), media_type="text/event-stream")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)