Spaces:
Paused
Paused
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| import asyncio | |
| import json | |
| import time | |
| from mlx_lm import load, generate | |
| app = FastAPI( | |
| title="API de generación de texto (Concurrente, Async, Stream y estilo Ollama)", | |
| description="API que utiliza el modelo gemma-3-text-4b-it-4bit, forzando el uso de CPU para evitar problemas con el backend Metal, y transmite la respuesta en streaming con formato estilo Ollama. Se controlan múltiples solicitudes simultáneas mediante un semáforo.", | |
| version="1.0.0" | |
| ) | |
| # Cargar el modelo y tokenizer utilizando la CPU | |
| MODEL_NAME = "mlx-community/gemma-3-text-4b-it-4bit" | |
| try: | |
| model, tokenizer = load(MODEL_NAME) # Forzar uso de CPU | |
| except Exception as e: | |
| raise RuntimeError(f"No se pudo cargar el modelo {MODEL_NAME}: {e}") | |
| # Límite máximo de tokens por chunk (ajústalo según tus necesidades) | |
| MAX_TOKENS_PER_CHUNK = 200 | |
| # Definir el número máximo de solicitudes concurrentes que utilizarán el modelo | |
| CONCURRENT_REQUESTS = 5 | |
| semaphore = asyncio.Semaphore(CONCURRENT_REQUESTS) | |
| class PromptRequest(BaseModel): | |
| prompt: str | |
| def split_response_into_chunks(response: str, tokenizer, max_tokens: int): | |
| """ | |
| Divide la respuesta en fragmentos (chunks) según la cantidad de tokens. | |
| Si la cantidad de tokens es menor o igual a max_tokens, se devuelve una lista con el texto completo. | |
| """ | |
| tokens = tokenizer.encode(response, add_special_tokens=False) | |
| if len(tokens) <= max_tokens: | |
| return [response] | |
| # Dividir los tokens en grupos de max_tokens | |
| token_chunks = [tokens[i:i+max_tokens] for i in range(0, len(tokens), max_tokens)] | |
| # Decodificar cada grupo de tokens en texto | |
| text_chunks = [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in token_chunks] | |
| return text_chunks | |
| async def generate_text_stream_ollama(request: PromptRequest): | |
| prompt = request.prompt | |
| # Si el tokenizer posee plantilla de chat, la aplicamos | |
| if tokenizer.chat_template is not None: | |
| messages = [{"role": "user", "content": prompt}] | |
| prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) | |
| # Usar el semáforo para limitar el número de solicitudes concurrentes | |
| async with semaphore: | |
| loop = asyncio.get_running_loop() | |
| try: | |
| response = await loop.run_in_executor(None, generate, model, tokenizer, prompt, True) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Error generando la respuesta: {e}") | |
| # Dividir la respuesta en chunks basados en tokens | |
| chunks = split_response_into_chunks(response, tokenizer, MAX_TOKENS_PER_CHUNK) | |
| async def stream_response(): | |
| # Enviar cada chunk en formato JSON siguiendo un estilo similar al de Ollama | |
| for chunk in chunks: | |
| event = { | |
| "id": f"stream-{int(time.time()*1000)}", # ID único para el fragmento | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": MODEL_NAME, | |
| "choices": [ | |
| { | |
| "delta": {"content": chunk}, | |
| "index": 0, | |
| "finish_reason": None | |
| } | |
| ] | |
| } | |
| # Cada mensaje se envía con el prefijo "data:" y dos saltos de línea, como en SSE | |
| yield f"data: {json.dumps(event)}\n\n" | |
| await asyncio.sleep(0.05) | |
| # Indicar el fin del stream | |
| yield "data: [DONE]\n\n" | |
| # Se utiliza "text/event-stream" para que el cliente interprete la respuesta como SSE | |
| return StreamingResponse(stream_response(), media_type="text/event-stream") | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |