Spaces:
Running
Running
| """ | |
| code-1b-chat-v2 Inference API | |
| OpenAI-compatible /v1/chat/completions endpoint. | |
| Downloads GGUF on first startup, then serves requests. | |
| """ | |
import asyncio
import json
import logging
import os
import threading
import time
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Iterator, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("api")

# ── Config ─────────────────────────────────────────────────────────────
# Model artifact: GGUF_FILE is fetched from the GGUF_REPO Hub repo and kept
# at MODEL_PATH; MODEL_ID is the name reported back to API clients.
GGUF_REPO = "rovdetection/code-1b-chat-v2-gguf"
GGUF_FILE = "code-1b-chat-v2-Q4_K_M.gguf"
MODEL_PATH = "/app/model.gguf"
MODEL_ID = "code-1b-chat-v2"
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when unset

# llama.cpp runtime settings, each overridable via an environment variable.
N_CTX = int(os.environ.get("N_CTX", "512"))  # matches training MAX_LEN
N_THREADS = int(os.environ.get("N_THREADS", "4"))
N_BATCH = int(os.environ.get("N_BATCH", "512"))

# Alpaca prompt template pieces (must match the SFT training format exactly).
SYSTEM_DEFAULT = (
    "Below is an instruction that describes a coding task. "
    "Write a response that appropriately completes the request."
)
# Section headers of the template; generation is cut off when one appears.
STOP_TOKENS = ["### Instruction:", "### Input:", "### Response:"]

# ── Global model instance ──────────────────────────────────────────────
# Populated by load_model(); None means the model is not available yet.
llm = None
def download_model():
    """
    Make sure the GGUF weights are present at MODEL_PATH.

    If the file already exists, only its size is logged. Otherwise GGUF_FILE is
    downloaded from the GGUF_REPO Hub repo into /app (authenticating with
    HF_TOKEN) and moved to MODEL_PATH when it landed under a different name.
    """
    if not os.path.exists(MODEL_PATH):
        log.info(f"Downloading {GGUF_FILE} from {GGUF_REPO}...")
        fetched = hf_hub_download(
            repo_id=GGUF_REPO,
            filename=GGUF_FILE,
            token=HF_TOKEN,
            local_dir="/app",
        )
        if fetched != MODEL_PATH:
            os.rename(fetched, MODEL_PATH)
        gigabytes = os.path.getsize(MODEL_PATH) / 1e9
        log.info(f"Download complete: {gigabytes:.2f} GB")
        return

    gigabytes = os.path.getsize(MODEL_PATH) / 1e9
    log.info(f"Model already present: {gigabytes:.2f} GB")
def load_model():
    """
    Instantiate the llama.cpp model from MODEL_PATH and store it in ``llm``.

    Context size, thread count and batch size come from N_CTX, N_THREADS and
    N_BATCH. llama_cpp is imported inside the function, so it is only loaded
    when this function runs.
    """
    global llm
    from llama_cpp import Llama

    log.info(f"Loading model: n_ctx={N_CTX}, n_threads={N_THREADS}")
    llama_kwargs = dict(
        model_path=MODEL_PATH,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        verbose=False,
    )
    llm = Llama(**llama_kwargs)
    log.info("Model loaded and ready.")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook: prepare the model on startup, release it on shutdown.

    Startup downloads the GGUF if needed and loads it into the global ``llm``.
    Both calls are synchronous and complete before the ``yield``.
    """
    global llm
    download_model()
    load_model()
    try:
        yield
    finally:
        # Rebind instead of ``del``: route handlers read the global ``llm`` and
        # test ``llm is None``; deleting the name would make them raise NameError.
        llm = None
# ── FastAPI app ────────────────────────────────────────────────────────
# `lifespan` fetches and loads the model at startup.
app = FastAPI(
    lifespan=lifespan,
    title="Code 1B Chat API",
    version="1.0.0",
)
# ── Schemas ────────────────────────────────────────────────────────────
class Message(BaseModel):
    """A single OpenAI-style chat message."""

    # "system", "user" or "assistant"; build_prompt ignores other roles.
    role: str
    content: str


class ChatRequest(BaseModel):
    """Request body accepted by the chat completions endpoint (OpenAI subset)."""

    # Accepted for client compatibility; responses always report MODEL_ID.
    model: Optional[str] = MODEL_ID
    messages: List[Message]
    # Sampling controls, forwarded to the llama.cpp completion call.
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    # True -> reply as a server-sent-event stream instead of one JSON object.
    stream: Optional[bool] = False
    # Extra stop strings; the Alpaca STOP_TOKENS are always appended.
    stop: Optional[List[str]] = None
# ── Prompt builder ─────────────────────────────────────────────────────
def build_prompt(messages: List[Message]) -> str:
    """
    Convert OpenAI-style messages to the Alpaca format used during SFT.

    The last ``system`` message (if any) replaces SYSTEM_DEFAULT as the preamble.
    Each ``user`` turn becomes a ``### Instruction:`` section and each
    ``assistant`` turn a completed ``### Response:`` section; messages with any
    other role are ignored. When the final kept turn is a user turn, the prompt
    ends with an open ``### Response:`` header for the model to complete.

    Raises:
        ValueError: if ``messages`` contains no ``user`` message.
    """
    system = SYSTEM_DEFAULT
    turns = []
    for msg in messages:
        if msg.role == "system":
            system = msg.content
        elif msg.role in ("user", "assistant"):
            turns.append((msg.role, msg.content))

    # An assistant-only history has nothing for the model to answer.
    if not any(role == "user" for role, _ in turns):
        raise ValueError("No user message found")

    parts = [f"{system}\n\n"]
    last_index = len(turns) - 1
    for i, (role, content) in enumerate(turns):
        if role == "user":
            parts.append(f"### Instruction:\n{content}\n\n")
            # Open a response slot only when nothing follows this turn. If an
            # assistant turn follows, it supplies its own header; opening one
            # here as well would duplicate "### Response:".
            if i == last_index:
                parts.append("### Response:\n")
        else:
            parts.append(f"### Response:\n{content}\n\n")
    return "".join(parts)
# ── Routes ─────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve a minimal HTML landing page with a curl example and a link to /docs."""
    return """
<html><body style="font-family:sans-serif;max-width:600px;margin:2rem auto">
<h2>💻 code-1b-chat-v2 API</h2>
<p>OpenAI-compatible inference API. Model: <code>code-1b-chat-v2</code></p>
<h3>Quick test</h3>
<pre style="background:#f5f5f5;padding:1rem;border-radius:8px">
curl -X POST /v1/chat/completions \\
-H "Content-Type: application/json" \\
-d '{
"model": "code-1b-chat-v2",
"messages": [{"role":"user","content":"Write a Python fibonacci function."}],
"max_tokens": 200
}'</pre>
<p><a href="/docs">📖 API docs</a></p>
</body></html>
"""
@app.get("/health")
def health():
    """
    Report service liveness and whether the model has been loaded.

    Returns a dict with ``status`` ("ok" once loaded, otherwise "loading"),
    the served ``model`` id and a ``model_loaded`` flag.
    """
    # Read the global once so "status" and "model_loaded" always agree.
    loaded = llm is not None
    return {
        "status": "ok" if loaded else "loading",
        "model": MODEL_ID,
        "model_loaded": loaded,
    }
@app.get("/v1/models")
def list_models():
    """List the single served model in the OpenAI ``/v1/models`` response shape."""
    return {
        "object": "list",
        "data": [{
            "id": MODEL_ID,
            "object": "model",
            "owned_by": "rovdetection",
            "permission": [],
        }]
    }
# Guards the single shared Llama instance: at most one completion may run on it
# at a time. A threading.Lock (not an asyncio.Lock) because generation runs in
# worker threads, and a stream's lock may be released from a different worker
# thread than the one that acquired it (allowed for threading.Lock).
_LLM_LOCK = threading.Lock()


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """
    OpenAI-compatible chat completion endpoint.

    Builds an Alpaca prompt from ``req.messages`` and runs the llama.cpp model.
    Returns a single ``chat.completion`` object, or, when ``req.stream`` is
    true, a server-sent-event stream of ``chat.completion.chunk`` objects
    terminated by ``data: [DONE]``. Generation always runs in a worker thread
    (never on the event loop) and is serialized through ``_LLM_LOCK``.

    Raises:
        HTTPException(503): the model has not been loaded yet.
        HTTPException(400): the messages contain no user turn, or (non-streaming)
            llama_cpp rejected the request with a ValueError, e.g. a prompt
            longer than the context window.
    """
    model = llm  # snapshot: the global is reset during shutdown
    if model is None:
        raise HTTPException(status_code=503, detail="Model is still loading. Try again shortly.")
    try:
        prompt = build_prompt(req.messages)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    stop = (req.stop or []) + STOP_TOKENS
    created = int(time.time())
    # uuid suffix: a timestamp alone collides for requests in the same second.
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    gen_kwargs = {
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "stop": stop,
    }

    if req.stream:
        def stream_generator() -> Iterator[str]:
            # A plain (sync) generator: StreamingResponse iterates it in a
            # threadpool, so token generation never blocks the event loop.
            # The lock is held until the stream ends or the generator is closed.
            with _LLM_LOCK:
                for chunk in model(prompt, stream=True, **gen_kwargs):
                    choice = chunk["choices"][0]
                    data = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": MODEL_ID,
                        "choices": [{
                            "index": 0,
                            "delta": {"content": choice["text"]},
                            "finish_reason": choice.get("finish_reason"),
                        }]
                    }
                    yield f"data: {json.dumps(data)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",
            }
        )

    # Non-streaming: run the blocking call in the threadpool under the lock.
    def _complete() -> dict:
        with _LLM_LOCK:
            return model(prompt, **gen_kwargs)

    try:
        output = await run_in_threadpool(_complete)
    except ValueError as e:
        # llama_cpp signals invalid requests (e.g. prompt exceeding the
        # context window) with ValueError; report them as client errors.
        raise HTTPException(status_code=400, detail=str(e))

    choice = output["choices"][0]
    usage = output.get("usage", {})
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": choice["text"].strip()},
            "finish_reason": choice["finish_reason"],
        }],
        "usage": {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        }
    }
# ── Entry point ────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string: this works
    # whatever the file is named, and avoids re-importing this script as a
    # second module named "app".
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )