import typing as t
import os
# from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import logging
from src.utils import (
    OpenAIClient,
    TogetherAIClient,
    GeminiClient,
    GroqClient,
    MistralClient,
)
from src.models_enums import ModelProvider

# load_dotenv()

# Fail fast if the Together API key is missing (.get() returns None for a
# missing key, so the assert fires with a clear message instead of a KeyError).
assert os.environ.get('TOGETHER_API_KEY') is not None, "TOGETHER_API_KEY is not set"

# Configure basic logging to see messages in stdout (and thus in HF Space logs)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RequestData(BaseModel):
    prompt: str
    max_tokens: int = 50
    system_prompt: t.Optional[str] = None
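
# Example request body matching the schema above (illustrative values only):
# {"prompt": "Tell me a joke", "max_tokens": 100, "system_prompt": "You are terse."}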


MODEL_PROVIDER2CLIENT = {
    ModelProvider.OPENAI.value: OpenAIClient,
    ModelProvider.GEMINI.value: GeminiClient,
    ModelProvider.TOGETHERAI.value: TogetherAIClient,
    ModelProvider.GROQ.value: GroqClient,
    ModelProvider.MISTRAL.value: MistralClient,
}
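
# Every client class is used through the same interface: constructed with
# `model=<model_name>` and awaited as `client(prompt=..., system_prompt=...,
# max_tokens=...)`, which lets the handler below treat providers uniformly.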


app = FastAPI()
logger.info("FastAPI app initialized.")


# The application starts without initializing a specific LLM client;
# the provider and model are selected per request, which keeps it flexible.
# `{model_name:path}` allows model names containing slashes,
# e.g. "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free".
@app.post("/generate/{model_provider}/{model_name:path}")
async def generate_text(
    model_provider: str,
    model_name: str,
    request: RequestData,
):
"""
Generates text using a specified LLM provider and model.
Example:
POST /generate/togetherai/meta-llama/Llama-3.3-70B-Instruct-Turbo-Free
with body: {"prompt": "...", "max_tokens": 100}
"""
logger.info(f"Received POST request to /generate/{model_provider}/{model_name}.")
# Check if the requested model provider exists
if model_provider not in MODEL_PROVIDER2CLIENT:
logger.error(f"Invalid model provider: {model_provider}")
raise HTTPException(
status_code=400,
detail=f"Invalid model provider: {model_provider}. "
f"Available providers: {[p.value for p in ModelProvider]}"
)
    try:
        # Get the correct client class and instantiate it dynamically
        llm_client_class = MODEL_PROVIDER2CLIENT[model_provider]
        llm_client = llm_client_class(model=model_name)

        # Call the client's async method
        output = await llm_client(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
        )
        return output
    except Exception as e:
        logger.error(
            f"Error during text generation for {model_provider}/{model_name}: {str(e)}",
            exc_info=True,
        )
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
logger.info("Received GET request to /health.")
return {"status": "ok"} |