|
|
""" |
|
|
Kirim OSS Safeguard R1 10B - FastAPI Server |
|
|
RESTful API server for model inference with safety controls |
|
|
""" |
|
|
|
|
|
import asyncio
import json
import logging
import os
import secrets
import time
import uuid
from datetime import datetime
from typing import Optional, List, Dict

from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The application object. Title, description and version are what FastAPI
# shows in the generated interactive docs.
app = FastAPI(
    title="Kirim OSS Safeguard API",
    description="API for Kirim OSS Safeguard R1 10B model with safety controls",
    version="1.0.0",
)

# Fully permissive CORS policy: every origin, method and header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# risky pairing. Depending on how CORSMiddleware resolves it, credentialed
# cross-origin requests from any site may be accepted. Confirm this is
# intended, and restrict allow_origins before a production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for ``POST /v1/completions``.

    Only ``prompt`` is required. Every other field has a default, and the
    numeric ones carry ``ge``/``le`` bounds enforced during validation.
    """

    # --- input -----------------------------------------------------------
    prompt: str = Field(..., description="Input prompt for generation")

    # --- sampling controls -----------------------------------------------
    max_tokens: int = Field(
        default=512, ge=1, le=8192, description="Maximum tokens to generate"
    )
    temperature: float = Field(
        default=0.7, ge=0.0, le=2.0, description="Sampling temperature"
    )
    top_p: float = Field(
        default=0.9, ge=0.0, le=1.0, description="Nucleus sampling parameter"
    )
    top_k: int = Field(
        default=50, ge=0, le=100, description="Top-k sampling parameter"
    )
    repetition_penalty: float = Field(
        default=1.1, ge=1.0, le=2.0, description="Repetition penalty"
    )
    stop_sequences: Optional[List[str]] = Field(
        default=None, description="Stop sequences"
    )

    # --- delivery and safety ---------------------------------------------
    stream: bool = Field(default=False, description="Stream the response")
    # NOTE(review): any string is accepted here; the three modes named in the
    # description are not enforced by validation.
    safety_mode: str = Field(
        default="moderate", description="Safety mode: strict, moderate, lenient"
    )
|
|
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """A single turn of a chat conversation."""

    # NOTE(review): the role is free text; "system"/"user"/"assistant" are
    # documented but not validated.
    role: str = Field(
        ...,
        description="Message role: system, user, or assistant",
    )
    content: str = Field(
        ...,
        description="Message content",
    )
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions``.

    ``messages`` is required. The sampling fields share the bounds and
    defaults of the corresponding ``GenerateRequest`` fields.
    """

    messages: List[ChatMessage] = Field(..., description="List of chat messages")

    # Descriptions added so the OpenAPI schema documents these fields the same
    # way it documents the completions request.
    max_tokens: int = Field(512, ge=1, le=8192, description="Maximum tokens to generate")
    temperature: float = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: float = Field(0.9, ge=0.0, le=1.0, description="Nucleus sampling parameter")

    stream: bool = Field(False, description="Stream the response")
    # NOTE(review): any string is accepted; the listed modes are not enforced.
    safety_mode: str = Field("moderate", description="Safety mode: strict, moderate, lenient")
|
|
|
|
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body for ``POST /v1/completions``."""

    # Envelope metadata.
    id: str = Field(..., description="Unique response ID")
    object: str = Field(default="text_completion", description="Object type")
    created: int = Field(..., description="Unix timestamp")
    model: str = Field(..., description="Model identifier")

    # Loosely typed payloads. Their keys are chosen by the handler that
    # builds the response, not enforced by this model.
    choices: List[Dict] = Field(..., description="Generated choices")
    usage: Dict = Field(..., description="Token usage statistics")
    safety: Dict = Field(..., description="Safety check results")
|
|
|
|
|
|
|
|
class ChatResponse(BaseModel):
    """Response body for ``POST /v1/chat/completions``."""

    # Envelope metadata.
    id: str = Field(..., description="Unique response ID")
    object: str = Field(default="chat.completion", description="Object type")
    created: int = Field(..., description="Unix timestamp")
    model: str = Field(..., description="Model identifier")

    # Loosely typed payloads. Their keys are chosen by the handler that
    # builds the response, not enforced by this model.
    choices: List[Dict] = Field(..., description="Generated choices")
    usage: Dict = Field(..., description="Token usage statistics")
    safety: Dict = Field(..., description="Safety check results")
|
|
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Metadata describing one served model."""

    id: str = Field(...)
    object: str = Field(default="model")
    created: int = Field(...)  # integer seconds since the epoch
    owned_by: str = Field(...)
    capabilities: List[str] = Field(...)
    max_tokens: int = Field(...)  # advertised token limit for the model
    safety_features: List[str] = Field(...)
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Payload returned by the health-check endpoint."""

    status: str = Field(...)  # human-readable readiness label
    model_loaded: bool = Field(...)
    version: str = Field(...)
    timestamp: str = Field(...)  # ISO-8601 formatted string
|
|
|
|
|
|
|
|
|
|
|
class AppState:
    """Mutable container for the server's runtime state."""

    def __init__(self):
        # Inference and safety components; all start out unset (None).
        self.model = self.safety = self.safety_wrapper = None
        # Counter of handled requests, starting at zero.
        self.request_count = 0
        # Wall-clock construction time, the reference point for uptime.
        self.start_time = time.time()
|
|
|
|
|
|
|
|
# Single module-level AppState instance read and updated by the endpoints below.
state = AppState()
|
|
|
|
|
|
|
|
|
|
|
async def verify_api_key(x_api_key: Optional[str] = Header(None)):
    """Validate the ``X-API-Key`` request header.

    Authentication is opt-in. If the ``KIRIM_API_KEYS`` environment variable
    is unset or empty, every request is accepted (the original, open
    behavior), and the header value is passed through, possibly ``None``.
    If the variable holds a comma-separated list of keys, the header must
    match one of them.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, or ``None`` if absent.

    Returns:
        The client-supplied API key (``None`` or empty when authentication
        is disabled and no key was sent).

    Raises:
        HTTPException: 401 when authentication is enabled and the key is
            missing or not recognised.
    """
    # Read on every call so keys can be changed without re-importing.
    configured = os.environ.get("KIRIM_API_KEYS", "")
    allowed_keys = [key.strip() for key in configured.split(",") if key.strip()]

    if not allowed_keys:
        # No keys configured: keep the previous accept-everything behavior.
        return x_api_key

    # compare_digest on bytes gives a constant-time comparison and also
    # handles non-ASCII header values.
    supplied = (x_api_key or "").encode()
    if not x_api_key or not any(
        secrets.compare_digest(supplied, key.encode()) for key in allowed_keys
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )
    return x_api_key
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize model state on application startup.

    NOTE(review): no loading logic is present yet, so ``state.model`` stays
    ``None`` unless something else assigns it. In that case the completion
    endpoints answer 503.
    """
    print("Loading model...")

    # TODO: load the model and safety components into ``state`` here.

    # Report success only when a model is actually attached. The original
    # printed the success message unconditionally.
    if state.model is not None:
        print("Model loaded successfully!")
    else:
        print("Warning: no model loaded; completion endpoints will return 503.")
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service readiness.

    Returns:
        HealthResponse with ``status`` "healthy" once ``state.model`` is set
        and "loading" otherwise, the matching ``model_loaded`` flag, the API
        version, and a local-time ISO-8601 timestamp.
    """
    # One identity check feeds both fields. Previously ``status`` used the
    # model's truthiness, so a loaded but falsy model object (e.g. one
    # defining __len__ == 0) reported "loading" with model_loaded=True.
    loaded = state.model is not None
    return HealthResponse(
        status="healthy" if loaded else "loading",
        model_loaded=loaded,
        version="1.0.0",
        timestamp=datetime.now().isoformat(),
    )
|
|
|
|
|
|
|
|
|
|
|
@app.get("/v1/models/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """Describe a single model by ID.

    Only "kirim-oss-safeguard-r1-10b" is known; any other ID gets a 404.
    ``created`` is filled with the current time on every call.
    """
    known_id = "kirim-oss-safeguard-r1-10b"
    if model_id == known_id:
        return ModelInfo(
            id=known_id,
            created=int(time.time()),
            owned_by="kirim-ai",
            capabilities=["text-generation", "chat", "safety-filtering"],
            max_tokens=8192,
            safety_features=[
                "hate_speech_detection",
                "violence_detection",
                "sexual_content_filtering",
                "illegal_activity_detection",
                "pii_redaction",
            ],
        )
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
|
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """Return every served model in a ``{"object": "list", "data": [...]}`` envelope."""
    entry = {
        "id": "kirim-oss-safeguard-r1-10b",
        "object": "model",
        "created": int(time.time()),  # current time, recomputed per call
        "owned_by": "kirim-ai",
    }
    return {"object": "list", "data": [entry]}
|
|
|
|
|
|
|
|
|
|
|
@app.post("/v1/completions", response_model=GenerateResponse)
async def generate_completion(
    request: GenerateRequest,
    api_key: str = Depends(verify_api_key)
):
    """Generate a text completion for ``request.prompt``.

    Currently returns a fixed demo string. The sampling fields on
    ``request`` and ``request.stream`` are not used yet, no safety filtering
    is applied, and the ``safety`` report is hard-coded as "safe".

    Token counts in ``usage`` are whitespace-split word counts, not
    tokenizer counts.

    Raises:
        HTTPException: 503 if no model is loaded. Any HTTPException raised
            while handling is propagated unchanged. Anything else becomes a
            500 whose details are logged server-side, not sent to the client.
    """
    # Identity check, consistent with ``model_loaded`` in /health and /v1/stats.
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    state.request_count += 1
    request_id = f"cmpl-{uuid.uuid4().hex[:24]}"

    try:
        # TODO: generate with state.model and run input/output safety checks.
        response_text = "This is a demo response. In production, this would be generated by the model."
        filtered_response = response_text

        # Count once rather than re-splitting both strings for total_tokens.
        prompt_tokens = len(request.prompt.split())
        completion_tokens = len(filtered_response.split())

        return GenerateResponse(
            id=request_id,
            object="text_completion",
            created=int(time.time()),
            model="kirim-oss-safeguard-r1-10b",
            choices=[
                {
                    "text": filtered_response,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            safety={
                "input_safe": True,
                "output_safe": True,
                "categories_flagged": [],
            },
        )
    except HTTPException:
        # Intentional HTTP errors (e.g. a future safety rejection) must not
        # be rewritten into a generic 500.
        raise
    except Exception as exc:
        # Log the full traceback for operators, and keep internal error
        # text out of the client-facing response.
        logging.getLogger(__name__).exception("Completion request %s failed", request_id)
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error (request id: {request_id})",
        ) from exc
|
|
|
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completion(
    request: ChatRequest,
    api_key: str = Depends(verify_api_key)
):
    """Generate an assistant reply for a chat conversation.

    Currently returns a fixed demo reply. The sampling fields on ``request``
    and ``request.stream`` are not used yet, no safety filtering is applied,
    and the ``safety`` report is hard-coded as "safe".

    Token counts in ``usage`` are whitespace-split word counts over all
    message contents, not tokenizer counts.

    Raises:
        HTTPException: 503 if no model is loaded. Any HTTPException raised
            while handling is propagated unchanged. Anything else becomes a
            500 whose details are logged server-side, not sent to the client.
    """
    # Identity check, consistent with ``model_loaded`` in /health and /v1/stats.
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    state.request_count += 1
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        # TODO: run an input safety check on the latest user message and an
        # output check on the reply once model/safety components exist. The
        # original extracted the last user message here but never used it.
        response_text = "This is a demo chat response. In production, this would be generated by the model."
        filtered_response = response_text

        # Sum the conversation once instead of twice (prompt and total).
        prompt_tokens = sum(len(m.content.split()) for m in request.messages)
        completion_tokens = len(filtered_response.split())

        return ChatResponse(
            id=request_id,
            object="chat.completion",
            created=int(time.time()),
            model="kirim-oss-safeguard-r1-10b",
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": filtered_response,
                    },
                    "finish_reason": "stop",
                }
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            safety={
                "input_safe": True,
                "output_safe": True,
                "categories_flagged": [],
            },
        )
    except HTTPException:
        # Intentional HTTP errors must not be rewritten into a generic 500.
        raise
    except Exception as exc:
        # Log the traceback server-side, and keep internal error text out of
        # the response.
        logging.getLogger(__name__).exception("Chat completion %s failed", request_id)
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error (request id: {request_id})",
        ) from exc
|
|
|
|
|
|
|
|
|
|
|
@app.get("/v1/stats")
async def get_stats():
    """Return the request counter, uptime in seconds, model-load flag and version."""
    stats = {}
    stats["request_count"] = state.request_count
    stats["uptime_seconds"] = time.time() - state.start_time
    stats["model_loaded"] = state.model is not None
    stats["version"] = "1.0.0"
    return stats
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: API name, version, and a map of the main routes."""
    routes = {
        "health": "/health",
        "models": "/v1/models",
        "completions": "/v1/completions",
        "chat": "/v1/chat/completions",
        "stats": "/v1/stats",
    }
    return {
        "message": "Kirim OSS Safeguard R1 10B API",
        "version": "1.0.0",
        "endpoints": routes,
        "documentation": "/docs",
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Listen on all network interfaces (0.0.0.0), port 8000, at info logging.
    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)