File size: 6,887 Bytes
db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b d33d331 db7889b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
# Entry point file for Hugging Face Spaces - OpenAI Compatible
import json
import logging
import os
import sys
import time
import uuid
from typing import Optional, List, Dict, Any, Literal

import requests
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
# Configure root logging once: INFO and above, timestamped, written to stdout.
# (basicConfig wraps `stream` in a StreamHandler and applies `format` to it.)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)
# Module-level logger used by the endpoints in this file.
logger = logging.getLogger(__name__)
# The FastAPI application; title and description are its metadata.
app = FastAPI(
    title="OpenAI-Compatible Chat API",
    description="A FastAPI application that provides an OpenAI-compatible interface",
)
# Models for OpenAI compatibility
class Message(BaseModel):
    """A single chat message: the speaker's role, its text, and an optional name."""

    role: str
    content: str
    name: Optional[str] = Field(default=None)
class ChatCompletionRequest(BaseModel):
    """Shape of an OpenAI-style chat completions request body."""

    model: str = Field(default="granite-3-2-8b-instruct")
    messages: List[Message]
    # Sampling / length controls, with the same defaults chat_completion falls back to.
    temperature: Optional[float] = Field(default=0.7)
    top_p: Optional[float] = Field(default=0.9)
    max_tokens: Optional[int] = Field(default=2048)
    stream: Optional[bool] = Field(default=False)
class ChatCompletionChoice(BaseModel):
    """One entry of a completion's ``choices`` list."""

    index: int
    message: Message
    finish_reason: str = Field(default="stop")
class Usage(BaseModel):
    """Token accounting block of a completion response (all fields required)."""

    prompt_tokens: int = Field(...)
    completion_tokens: int = Field(...)
    total_tokens: int = Field(...)
class ChatCompletionResponse(BaseModel):
    """Top-level OpenAI-style ``chat.completion`` response object."""

    id: str = Field(...)
    object: str = Field(default="chat.completion")
    # Creation time as an integer timestamp.
    created: int = Field(...)
    model: str = Field(...)
    choices: List[ChatCompletionChoice] = Field(...)
    usage: Usage = Field(...)
# Simple API endpoint for debugging
@app.get("/health")
async def health_check():
    """Health check: report a fixed ``ok`` status plus the current ``time.time()``."""
    now = time.time()
    return {"status": "ok", "timestamp": now}
# OpenAI-compatible chat completions endpoint backed by the Granite API.


def _chat_error_response(message: str, error_type: str, status: int) -> JSONResponse:
    """Build an OpenAI-style error body sent with the matching HTTP status.

    Returning a real non-2xx status (rather than a 200 with an error dict)
    is what lets OpenAI client libraries detect the failure.
    """
    return JSONResponse(
        status_code=status,
        content={"error": {"message": message, "type": error_type, "status": status}},
    )


def _chat_extract_reply(response_json: Any) -> tuple:
    """Return ``(content, finish_reason)`` from an upstream completion payload.

    Reads ``choices[0].message.content`` and ``choices[0].finish_reason``.
    A ``None``/non-string content becomes ``""``; a missing or non-string
    finish_reason becomes ``"stop"``. If the payload has no usable
    ``choices[0].message``, fall back to ``(str(response_json), "stop")``,
    which preserves the endpoint's lenient handling of other response shapes.
    """
    choices = response_json.get("choices") if isinstance(response_json, dict) else None
    if isinstance(choices, list) and choices:
        first = choices[0]
        message = first.get("message") if isinstance(first, dict) else None
        if isinstance(message, dict):
            content = message.get("content")
            finish_reason = first.get("finish_reason")
            return (
                content if isinstance(content, str) else "",
                finish_reason if isinstance(finish_reason, str) else "stop",
            )
    # Fallback in case the response structure is different
    return str(response_json), "stop"


def _chat_usage(response_json: Any, messages: List[Any], reply: str) -> Dict[str, int]:
    """Return OpenAI ``usage`` counts, preferring the upstream's own numbers.

    If the upstream reply lacks integer ``prompt_tokens``/``completion_tokens``,
    fall back to a rough whitespace word count. Word counts are not real
    tokenizer tokens. Only string ``content`` values are counted, so
    ``None`` or structured content does not crash the estimate.
    """
    usage = response_json.get("usage") if isinstance(response_json, dict) else None
    if isinstance(usage, dict):
        prompt = usage.get("prompt_tokens")
        completion = usage.get("completion_tokens")
        if isinstance(prompt, int) and isinstance(completion, int):
            return {
                "prompt_tokens": prompt,
                "completion_tokens": completion,
                "total_tokens": prompt + completion,
            }
    prompt = sum(
        len(msg["content"].split())
        for msg in messages
        if isinstance(msg, dict) and isinstance(msg.get("content"), str)
    )
    completion = len(reply.split())
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": prompt + completion,
    }


@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
    """Proxy an OpenAI-style chat completion request to the Granite API.

    Reads ``model``, ``messages``, ``temperature``, ``top_p`` and
    ``max_tokens`` from the JSON body and forwards them upstream. It then
    reshapes the upstream reply into an OpenAI ``chat.completion`` object.
    Every other field, including ``stream``, is ignored.
    NOTE(review): clients sending ``stream: true`` still receive a single
    non-streamed JSON body.

    Errors use the body ``{"error": {"message", "type", "status"}}`` and a
    non-2xx HTTP status:

    - 400 for a malformed request body.
    - The upstream status for an upstream error reply, or 502 when that
      status is below 400.
    - 502 when the upstream is unreachable or returns non-JSON.
    - 500 for any unexpected exception.
    """
    try:
        try:
            data = await request.json()
        except ValueError as exc:  # invalid JSON or undecodable bytes
            return _chat_error_response(
                f"Invalid JSON body: {exc}", "invalid_request_error", 400
            )
        logger.info(f"Received request: {data}")
        if not isinstance(data, dict):
            return _chat_error_response(
                "Request body must be a JSON object", "invalid_request_error", 400
            )

        messages = data.get("messages", [])
        if not isinstance(messages, list):
            return _chat_error_response(
                "'messages' must be a list", "invalid_request_error", 400
            )
        model = data.get("model", "granite-3-2-8b-instruct")

        # Format request for granite API
        granite_data = {
            "messages": messages,
            "model": model,
            "max_tokens": data.get("max_tokens", 2048),
            "temperature": data.get("temperature", 0.7),
            "top_p": data.get("top_p", 0.9),
        }

        url = os.environ.get(
            "GRANITE_API_URL",
            "https://d18n68ssusgr7r.cloudfront.net/v1/chat/completions",
        )
        # NOTE(review): the fallback credential below is committed to source.
        # Rotate it and provide GRANITE_API_KEY as a deployment secret instead.
        api_key = os.environ.get(
            "GRANITE_API_KEY", "89de4a8b-9dc6-4617-86a0-28690278b651"
        )
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        logger.info(f"Sending request to granite API: {granite_data}")
        try:
            # requests is blocking. Run it in a worker thread so the event loop
            # keeps serving other clients, and bound the wait:
            # (connect, read) seconds, with a generous read for long generations.
            response = await run_in_threadpool(
                requests.post, url, headers=headers, json=granite_data, timeout=(10, 300)
            )
        except requests.RequestException as exc:
            logger.error(f"Granite API request failed: {exc}")
            return _chat_error_response(
                f"Error from upstream API: {exc}", "api_error", 502
            )
        logger.info(f"Granite API response status: {response.status_code}")

        if response.status_code != 200:
            logger.error(f"Error from granite API: {response.text}")
            status = response.status_code if response.status_code >= 400 else 502
            return _chat_error_response(
                f"Error from upstream API: {response.text}", "api_error", status
            )

        try:
            response_json = response.json()
        except ValueError:
            # Catch ValueError rather than json.JSONDecodeError. requests may
            # raise its own (or simplejson's) decode error, and all of them
            # subclass ValueError.
            logger.error(f"Failed to parse JSON response: {response.text}")
            return _chat_error_response(
                "Failed to parse response from upstream API", "api_error", 502
            )
        logger.info(f"Granite API response: {response_json}")

        assistant_message, finish_reason = _chat_extract_reply(response_json)
        openai_response = {
            # uuid keeps ids unique even for requests in the same second.
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": assistant_message},
                    "finish_reason": finish_reason,
                }
            ],
            "usage": _chat_usage(response_json, messages, assistant_message),
        }
        logger.info("Returning OpenAI-compatible response")
        return openai_response
    except Exception as e:
        logger.exception(f"Exception in chat_completion: {str(e)}")
        return _chat_error_response(
            f"Internal server error: {str(e)}", "server_error", 500
        )
# Alternative version of the endpoint that directly passes through the raw granite API response
@app.post("/raw/chat/completions")
async def raw_chat_completion(request: Request):
    """Forward the JSON request body unchanged to the Granite API and relay the reply.

    On success, the upstream JSON body is returned with the upstream HTTP
    status code, so upstream errors stay visible as errors. Failures use a
    flat ``{"error": ...}`` body:

    - 400 for a malformed request body.
    - 502 when the upstream is unreachable.
    - The upstream status (or 502 when it is below 400) when the upstream
      replies with non-JSON; ``raw_response`` then holds the upstream text.
    - 500 for any unexpected exception.
    """
    try:
        try:
            data = await request.json()
        except ValueError as exc:  # invalid JSON or undecodable bytes
            return JSONResponse(status_code=400, content={"error": f"Invalid JSON body: {exc}"})
        logger.info(f"Received raw request: {data}")

        # Forward to granite API
        url = os.environ.get(
            "GRANITE_API_URL",
            "https://d18n68ssusgr7r.cloudfront.net/v1/chat/completions",
        )
        # NOTE(review): the fallback credential below is committed to source.
        # Rotate it and provide GRANITE_API_KEY as a deployment secret instead.
        api_key = os.environ.get(
            "GRANITE_API_KEY", "89de4a8b-9dc6-4617-86a0-28690278b651"
        )
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        try:
            # Blocking HTTP call runs off the event loop, with a bounded wait.
            response = await run_in_threadpool(
                requests.post, url, headers=headers, json=data, timeout=(10, 300)
            )
        except requests.RequestException as exc:
            logger.error(f"Raw API request failed: {exc}")
            return JSONResponse(status_code=502, content={"error": str(exc)})
        logger.info(f"Raw API response status: {response.status_code}")

        try:
            result = response.json()
        except ValueError:
            # requests' decode error subclasses ValueError, whichever JSON
            # backend is installed.
            logger.error(f"Failed to parse raw JSON response: {response.text}")
            status = response.status_code if response.status_code >= 400 else 502
            return JSONResponse(
                status_code=status,
                content={"error": "Failed to parse response", "raw_response": response.text},
            )
        return JSONResponse(status_code=response.status_code, content=result)
    except Exception as e:
        logger.exception(f"Exception in raw_chat_completion: {str(e)}")
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/")
async def root():
    """Return a welcome message, a running status, and a map of the routes."""
    endpoints = {
        "/v1/chat/completions": "OpenAI-compatible chat completions endpoint",
        "/raw/chat/completions": "Direct passthrough to the granite API",
        "/health": "Health check endpoint",
    }
    return {
        "message": "Welcome to the OpenAI-Compatible Chat API",
        "status": "running",
        "endpoints": endpoints,
    }
if __name__ == "__main__":
    # 7860 remains the default port. The PORT environment variable can
    # override it for other hosts or local runs.
    port = int(os.environ.get("PORT", "7860"))
    logger.info("Starting application on port %d", port)
    uvicorn.run(app, host="0.0.0.0", port=port)