from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel import asyncio import json import time from mlx_lm import load, generate app = FastAPI( title="API de generación de texto (Concurrente, Async, Stream y estilo Ollama)", description="API que utiliza el modelo gemma-3-text-4b-it-4bit, forzando el uso de CPU para evitar problemas con el backend Metal, y transmite la respuesta en streaming con formato estilo Ollama. Se controlan múltiples solicitudes simultáneas mediante un semáforo.", version="1.0.0" ) # Cargar el modelo y tokenizer utilizando la CPU MODEL_NAME = "mlx-community/gemma-3-text-4b-it-4bit" try: model, tokenizer = load(MODEL_NAME) # Forzar uso de CPU except Exception as e: raise RuntimeError(f"No se pudo cargar el modelo {MODEL_NAME}: {e}") # Límite máximo de tokens por chunk (ajústalo según tus necesidades) MAX_TOKENS_PER_CHUNK = 200 # Definir el número máximo de solicitudes concurrentes que utilizarán el modelo CONCURRENT_REQUESTS = 5 semaphore = asyncio.Semaphore(CONCURRENT_REQUESTS) class PromptRequest(BaseModel): prompt: str def split_response_into_chunks(response: str, tokenizer, max_tokens: int): """ Divide la respuesta en fragmentos (chunks) según la cantidad de tokens. Si la cantidad de tokens es menor o igual a max_tokens, se devuelve una lista con el texto completo. """ tokens = tokenizer.encode(response, add_special_tokens=False) if len(tokens) <= max_tokens: return [response] # Dividir los tokens en grupos de max_tokens token_chunks = [tokens[i:i+max_tokens] for i in range(0, len(tokens), max_tokens)] # Decodificar cada grupo de tokens en texto text_chunks = [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in token_chunks] return text_chunks @app.post("/generate-stream-ollama") async def generate_text_stream_ollama(request: PromptRequest): prompt = request.prompt # Si el tokenizer posee plantilla de chat, la aplicamos if tokenizer.chat_template is not None: messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # Usar el semáforo para limitar el número de solicitudes concurrentes async with semaphore: loop = asyncio.get_running_loop() try: response = await loop.run_in_executor(None, generate, model, tokenizer, prompt, True) except Exception as e: raise HTTPException(status_code=500, detail=f"Error generando la respuesta: {e}") # Dividir la respuesta en chunks basados en tokens chunks = split_response_into_chunks(response, tokenizer, MAX_TOKENS_PER_CHUNK) async def stream_response(): # Enviar cada chunk en formato JSON siguiendo un estilo similar al de Ollama for chunk in chunks: event = { "id": f"stream-{int(time.time()*1000)}", # ID único para el fragmento "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_NAME, "choices": [ { "delta": {"content": chunk}, "index": 0, "finish_reason": None } ] } # Cada mensaje se envía con el prefijo "data:" y dos saltos de línea, como en SSE yield f"data: {json.dumps(event)}\n\n" await asyncio.sleep(0.05) # Indicar el fin del stream yield "data: [DONE]\n\n" # Se utiliza "text/event-stream" para que el cliente interprete la respuesta como SSE return StreamingResponse(stream_response(), media_type="text/event-stream") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)