"""Generation-specific API endpoints."""
|
|
import logging
from datetime import datetime
from typing import Optional

import pydantic
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool

from services import generate_model_manager
|
|
# Module-level logger, named after this module (``__name__``) so logging
# configuration can target it specifically.
logger = logging.getLogger(__name__)

# Router that collects the generation endpoints defined below; presumably
# included into the FastAPI application elsewhere (not visible here).
router = APIRouter()
|
|
|
|
class GenerateRequest(pydantic.BaseModel):
    """Request body for the ``POST /generate`` endpoint.

    ``max_length`` and ``num_beams`` are forwarded to the generation model;
    when supplied as integers they must be >= 1, so invalid values are
    rejected during request validation (HTTP 422) instead of being handed
    to the model.
    """

    # Input text to generate from.
    input_text: str
    # Maximum length of the generated text. An explicit ``None`` is still
    # accepted and passed through unchanged (how the model manager treats
    # ``None`` is not visible here).
    max_length: Optional[int] = pydantic.Field(default=128, ge=1)
    # Number of beams for beam search. An explicit ``None`` is still
    # accepted and passed through unchanged.
    num_beams: Optional[int] = pydantic.Field(default=4, ge=1)
|
|
|
|
class GenerateResponse(pydantic.BaseModel):
    """Response body returned by the ``POST /generate`` endpoint."""

    # The input text from the request, echoed back.
    input_text: str = pydantic.Field(...)
    # Text produced by the generation model.
    generated_text: str = pydantic.Field(...)
    # ISO-8601 formatted string recorded when the response is built.
    timestamp: str = pydantic.Field(...)
|
|
|
|
@router.post("/generate", response_model=GenerateResponse, tags=["Text Generation"])
async def generate_text(request: GenerateRequest):
    """
    Generate text using the T5 model

    - **input_text**: The input text for generation
    - **max_length**: Maximum length of generated text (default: 128)
    - **num_beams**: Number of beams for beam search (default: 4)

    Returns generated text. Responds with HTTP 500 if generation fails.
    """
    # The model call is synchronous (its return value is used directly as
    # the generated string). Calling it inline inside an ``async def`` would
    # block the event loop, and with it every other in-flight request, for
    # the whole generation. Run it in the threadpool instead, as FastAPI
    # does for plain ``def`` endpoints.
    try:
        result = await run_in_threadpool(
            generate_model_manager.generate,
            request.input_text,
            max_length=request.max_length,
            num_beams=request.num_beams,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e))
        # dropped; chaining with `from e` preserves the original cause.
        logger.exception("Generation error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Generation failed: {str(e)}"
        ) from e

    # Lazy %-formatting: the message is only built if INFO is enabled.
    # NOTE(review): this logs the full generated text at INFO level;
    # consider DEBUG if outputs may contain user data.
    logger.info("Generated text: %s", result)

    return GenerateResponse(
        input_text=request.input_text,
        generated_text=result,
        # .astimezone() attaches the local UTC offset, so the timestamp keeps
        # the same wall-clock value but is no longer ambiguous to clients.
        timestamp=datetime.now().astimezone().isoformat(),
    )
|
|