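"""OpenAI-compatible chat API that proxies requests to a granite model endpoint.

/v1/chat/completions accepts OpenAI-style requests, forwards the generation
parameters upstream, and re-wraps the reply as a chat.completion object;
/raw/chat/completions is a direct passthrough; /health is a liveness probe.

Run the module directly to serve on port 7860. An illustrative request
(assuming the server is running on localhost):

    curl http://localhost:7860/v1/chat/completions \
      -H "Content-Type: application/json" \
      -d '{"messages": [{"role": "user", "content": "Hello"}]}'
"""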
import json
import logging
import os
import sys
import time
from typing import List, Optional

import requests
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Upstream granite endpoint shared by both proxy routes. The bearer token can
# be overridden with the GRANITE_API_KEY environment variable so it does not
# have to live in source control.
GRANITE_API_URL = "https://d18n68ssusgr7r.cloudfront.net/v1/chat/completions"
GRANITE_API_KEY = os.environ.get(
    "GRANITE_API_KEY", "89de4a8b-9dc6-4617-86a0-28690278b651"
)
GRANITE_HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {GRANITE_API_KEY}",
}

app = FastAPI(
    title="OpenAI-Compatible Chat API",
    description="A FastAPI application that provides an OpenAI-compatible interface",
)
|
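# Pydantic models mirroring the OpenAI chat-completions schema. The route
# handlers below parse raw JSON rather than binding to these models, so they
# mainly document the request/response shapes this API expects.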
class Message(BaseModel):
    role: str
    content: str
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str = "granite-3-2-8b-instruct"
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 2048
    # Accepted for OpenAI compatibility; streaming is not implemented.
    stream: Optional[bool] = False


class ChatCompletionChoice(BaseModel):
    index: int
    message: Message
    finish_reason: str = "stop"


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Usage
|
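# Lightweight liveness probe for deployment health checks.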
@app.get("/health")
async def health_check():
    return {"status": "ok", "timestamp": time.time()}
|
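# OpenAI-compatible endpoint: takes an OpenAI-style request, forwards the
# generation parameters to the granite API, and re-wraps the reply in a
# chat.completion envelope.
# NOTE: `requests` is synchronous, so the event loop blocks while the upstream
# call is in flight; an async client such as httpx would avoid that.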
@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
    """Accept an OpenAI-style chat request and proxy it to the granite API."""
    try:
        data = await request.json()
        logger.info(f"Received request: {data}")

        messages = data.get("messages", [])
        model = data.get("model", "granite-3-2-8b-instruct")
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.9)
        max_tokens = data.get("max_tokens", 2048)

        granite_data = {
            "messages": messages,
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        logger.info(f"Sending request to granite API: {granite_data}")
        response = requests.post(
            GRANITE_API_URL, headers=GRANITE_HEADERS, json=granite_data, timeout=60
        )
        logger.info(f"Granite API response status: {response.status_code}")

        if response.status_code != 200:
            logger.error(f"Error from granite API: {response.text}")
            # Propagate the upstream status code instead of returning 200.
            return JSONResponse(
                status_code=response.status_code,
                content={
                    "error": {
                        "message": f"Error from upstream API: {response.text}",
                        "type": "api_error",
                        "status": response.status_code,
                    }
                },
            )

        try:
            response_json = response.json()
            logger.info(f"Granite API response: {response_json}")
        except json.JSONDecodeError:
            logger.error(f"Failed to parse JSON response: {response.text}")
            return JSONResponse(
                status_code=502,
                content={
                    "error": {
                        "message": "Failed to parse upstream response",
                        "type": "api_error",
                        "status": 502,
                    }
                },
            )

        # Pull the assistant reply out of the upstream payload; fall back to
        # the raw payload as text if the shape is unexpected.
        assistant_message = ""
        if "choices" in response_json and len(response_json["choices"]) > 0:
            assistant_message = response_json["choices"][0]["message"]["content"]
        else:
            assistant_message = str(response_json)

        # Rough token counts based on whitespace splitting; the upstream API
        # does not always return usage, so these are approximations.
        prompt_tokens = sum(len(msg.get("content", "").split()) for msg in messages)
        completion_tokens = len(assistant_message.split())

        openai_response = {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": assistant_message},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

        logger.info("Returning OpenAI-compatible response")
        return openai_response
    except Exception as e:
        logger.exception(f"Exception in chat_completion: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "message": f"Internal server error: {str(e)}",
                    "type": "server_error",
                    "status": 500,
                }
            },
        )
|
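# Debugging passthrough: forwards the request body to the granite API without
# translation and returns the upstream response unchanged.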
@app.post("/raw/chat/completions")
async def raw_chat_completion(request: Request):
    """Forward the request body to the granite API and return its response as-is."""
    try:
        data = await request.json()
        logger.info(f"Received raw request: {data}")

        response = requests.post(
            GRANITE_API_URL, headers=GRANITE_HEADERS, json=data, timeout=60
        )
        logger.info(f"Raw API response status: {response.status_code}")

        try:
            # Mirror the upstream status code so callers see real errors.
            return JSONResponse(
                status_code=response.status_code, content=response.json()
            )
        except json.JSONDecodeError:
            logger.error(f"Failed to parse raw JSON response: {response.text}")
            return JSONResponse(
                status_code=502,
                content={
                    "error": "Failed to parse response",
                    "raw_response": response.text,
                },
            )
    except Exception as e:
        logger.exception(f"Exception in raw_chat_completion: {str(e)}")
        return JSONResponse(status_code=500, content={"error": str(e)})
|
@app.get("/")
async def root():
    return {
        "message": "Welcome to the OpenAI-Compatible Chat API",
        "status": "running",
        "endpoints": {
            "/v1/chat/completions": "OpenAI-compatible chat completions endpoint",
            "/raw/chat/completions": "Direct passthrough to the granite API",
            "/health": "Health check endpoint",
        },
    }


if __name__ == "__main__":
    logger.info("Starting application on port 7860")
    uvicorn.run(app, host="0.0.0.0", port=7860)