import time
import uuid
import json
import asyncio
import threading
from typing import AsyncGenerator, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse

from aetheris.api.schemas import (
    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk,
    ChatCompletionChoice, ChatMessage, ChatCompletionChunkChoice, ChatCompletionChunkDelta,
    CompletionRequest, CompletionResponse, CompletionChoice,
    ModelList, ModelCard
)
from aetheris.inference import InferenceEngine

app = FastAPI(title="Aetheris API", version="0.1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance
engine: Optional[InferenceEngine] = None


def get_engine() -> InferenceEngine:
    global engine
    if engine is None:
        # Defaults, ideally loaded from config/env
        engine = InferenceEngine()
    return engine

@app.on_event("startup")
async def startup_event():
    # Eagerly construct the engine when the server starts
    get_engine()

@app.get("/")
async def root():
    return {"status": "running", "message": "Aetheris API is active. Use /v1/chat/completions for inference."}

@app.get("/v1/models")
async def list_models():
    return ModelList(data=[ModelCard(id="aetheris-hybrid-mamba-moe")])

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    engine = get_engine()

    # Simple prompt construction from messages
    prompt = ""
    for msg in request.messages:
        prompt += f"{msg.role}: {msg.content}\n"
    prompt += "assistant: "

    request_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    if request.stream:
        async def event_generator() -> AsyncGenerator[str, None]:
            # First chunk carries only the assistant role
            yield json.dumps(ChatCompletionChunk(
                id=request_id,
                created=created_time,
                model=request.model,
                choices=[ChatCompletionChunkChoice(
                    index=0,
                    delta=ChatCompletionChunkDelta(role="assistant"),
                    finish_reason=None
                )]
            ).model_dump())

            # Offload synchronous generation to a thread to avoid blocking the event loop
            queue: asyncio.Queue = asyncio.Queue()
            loop = asyncio.get_running_loop()
            stop_event = threading.Event()

            def producer():
                try:
                    # Run the synchronous generator
                    for token in engine.generate(
                        prompt=prompt,
                        max_new_tokens=request.max_tokens or 100,
                        temperature=request.temperature,
                        top_p=request.top_p,
                        repetition_penalty=1.0 + request.frequency_penalty,
                        stream=True
                    ):
                        if stop_event.is_set():
                            break
                        # Schedule the put() coroutine on the main loop
                        asyncio.run_coroutine_threadsafe(queue.put(token), loop)
                except Exception as e:
                    print(f"Generation error: {e}")
                finally:
                    # Signal done
                    asyncio.run_coroutine_threadsafe(queue.put(None), loop)

            thread = threading.Thread(target=producer, daemon=True)
            thread.start()

            try:
                while True:
                    token = await queue.get()
                    if token is None:
                        break
                    yield json.dumps(ChatCompletionChunk(
                        id=request_id,
                        created=created_time,
                        model=request.model,
                        choices=[ChatCompletionChunkChoice(
                            index=0,
                            delta=ChatCompletionChunkDelta(content=token),
                            finish_reason=None
                        )]
                    ).model_dump())

                # Final empty-delta chunk signals completion
                yield json.dumps(ChatCompletionChunk(
                    id=request_id,
                    created=created_time,
                    model=request.model,
                    choices=[ChatCompletionChunkChoice(
                        index=0,
                        delta=ChatCompletionChunkDelta(),
                        finish_reason="stop"
                    )]
                ).model_dump())
                yield "[DONE]"
            finally:
                stop_event.set()

        return EventSourceResponse(event_generator())
    else:
        generated_text = engine.generate_full(
            prompt=prompt,
            max_new_tokens=request.max_tokens or 100,
            temperature=request.temperature,
            top_p=request.top_p,
            repetition_penalty=1.0 + request.frequency_penalty
        )
        return ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=generated_text),
                finish_reason="stop"
            )],
            # Approximated: character counts stand in for token counts
            usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}
        )

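# Illustrative client call (an assumption for local development; the host and
# port are not defined in this file). With the server running, a streaming
# request can be issued with curl:
#
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "aetheris-hybrid-mamba-moe", "stream": true,
#          "messages": [{"role": "user", "content": "Hello"}]}'
#
# Each SSE event carries a JSON-encoded ChatCompletionChunk; the stream ends
# with a final "[DONE]" event.
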
@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    engine = get_engine()

    prompt = request.prompt
    if isinstance(prompt, list):
        prompt = prompt[0]  # Handle single prompt for now

    request_id = f"cmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    if request.stream:
        # Streaming for this endpoint is not implemented in this demo; the
        # logic would mirror the chat streaming flow above. Fall through to
        # non-streaming generation for now.
        pass  # TODO: Implement streaming for completions

    generated_text = engine.generate_full(
        prompt=prompt,
        max_new_tokens=request.max_tokens or 16,
        temperature=request.temperature,
        top_p=request.top_p,
        repetition_penalty=1.0 + request.frequency_penalty
    )
    return CompletionResponse(
        id=request_id,
        created=created_time,
        model=request.model,
        choices=[CompletionChoice(
            text=generated_text,
            index=0,
            logprobs=None,
            finish_reason="length"  # or "stop"
        )],
        # Approximated: character counts stand in for token counts
        usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}
    )

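# Local development entry point (a minimal sketch; serving with uvicorn and the
# host/port below are assumptions, not specified elsewhere in this file).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)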