Spaces:
Sleeping
Sleeping
| import base64 | |
| import logging | |
| from typing import List, Optional | |
| from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import AnyHttpUrl, BaseModel, UrlConstraints | |
| from contextlib import asynccontextmanager | |
| from PIL import Image | |
| from config import get_settings | |
| import uvicorn | |
| from utils.audio_utils import AudioUtils | |
| from utils.caption_utils import ImageCaptioning | |
| from utils.image_utils import UrlTest | |
| from utils.topic_generation import TopicGenerator | |
# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)

# Root logging configuration: emit INFO and above.
logging.basicConfig(level=logging.INFO)
# Pydantic models for request/response


# Response body returned by the topic-generation handler.
class TopicResponse(BaseModel):
    # Generated topic strings.
    topics: List[str]

    # Caption for the processed image; the text-only path passes None.
    caption: Optional[str]
# Response body returned by the text-to-speech handler.
class AudioResponse(BaseModel):
    # Base64 text form of the synthesized audio bytes.
    audio_base64: str
# Response body returned by the audio-transcription handler.
class TranscriptionResponse(BaseModel):
    # Text produced by transcribing the uploaded audio file.
    audio_transcription: str
# Context manager for startup and shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared helper objects on startup and log on shutdown.

    The function is handed to ``FastAPI(lifespan=...)``, which expects an
    async context manager; the ``asynccontextmanager`` decorator turns this
    async generator into one. Code before ``yield`` is the startup phase,
    code after it is the shutdown phase.

    Args:
        app: The application whose ``state`` receives the helper instances
            that the request handlers read back later.
    """
    # Startup: instantiate each helper once and share it through app.state.
    app.state.topic_generator = TopicGenerator()
    app.state.img_caption = ImageCaptioning()
    app.state.audio_utils = AudioUtils()
    app.state.url_utils = UrlTest()
    logger.info("Application startup complete")

    yield

    # Shutdown
    logger.info("Application shutdown")
# Application instance; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(title="Rediones API", lifespan=lifespan)
# CORS
async def startup_event():
    """Attach CORS middleware to ``app`` when allowed origins are configured.

    Returns without touching the application if ``ALLOWED_ORIGINS`` on the
    settings object is falsy.
    """
    # NOTE(review): nothing visible in this file registers or calls this
    # coroutine - confirm how (and whether) it is hooked into the app.
    cfg = get_settings()
    if not cfg.ALLOWED_ORIGINS:
        return

    cors_options = dict(
        allow_origins=cfg.ALLOWED_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.add_middleware(CORSMiddleware, **cors_options)
async def root():
    """Return the fixed welcome message payload."""
    welcome = dict(message="Welcome To Rediones API")
    return welcome
async def health():
    """Return a fixed payload signalling that the service is up."""
    report = dict(status="OK")
    return report
async def generate_topic(
    img: UploadFile = File(None),
    text: Optional[str] = Form(None),
    img_url: Optional[AnyHttpUrl] = Form(None)
):
    """Generate topics from text, an uploaded image, or an image URL.

    With text only, the topic generator is used and no caption is returned.
    With an image (file or URL), the captioning model receives the image and
    the optional text and supplies both topics and caption.

    Args:
        img: Uploaded image file; its name must end in .jpg, .jpeg or .png.
        text: Optional free text.
        img_url: Optional image URL; mutually exclusive with ``img``.

    Returns:
        TopicResponse: The topics and, for image input, the caption.

    Raises:
        HTTPException: 400 if both ``img`` and ``img_url`` are supplied, if
            the upload has a missing or unsupported file name, or if no input
            is supplied at all; 500 for any other failure.
    """
    try:
        if img_url and img:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Only one of image_url or img can be accepted"
            )

        if not (img or img_url):
            if text:
                generated_topics = app.state.topic_generator.generate_topics(text)
                return TopicResponse(topics=generated_topics, caption=None)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Enter text or image. Image URL and image file are mutually exclusive."
            )

        if img:
            # An upload may arrive without a file name; treat that as an
            # unsupported name instead of failing on None.lower().
            filename = img.filename or ""
            if not filename.lower().endswith((".jpg", ".png", ".jpeg")):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Image file must be ended with .jpg, .png, .jpeg"
                )
            img_file_object = Image.open(img.file)
        else:
            # NOTE(review): img_url is forwarded as the validated AnyHttpUrl
            # value - confirm load_image accepts it rather than a plain str.
            img_file_object = app.state.url_utils.load_image(img_url)

        capt = app.state.img_caption.combo_model(img_file_object, text)
        logger.debug("Caption model output: %s", capt)
        return TopicResponse(topics=capt["topics"], caption=capt["caption"])
    except HTTPException:
        # The 400 responses raised above are deliberate; without this clause
        # the generic handler below would turn every one of them into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in generate_topic: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred"
        ) from e
async def generate_audio(text: str):
    """Synthesize speech for ``text`` and return it base64-encoded.

    Args:
        text: The text handed to the audio helper's ``speak`` method.

    Returns:
        AudioResponse: Holds the base64 string of the bytes that ``speak``
        returned.

    Raises:
        HTTPException: 500 for any unexpected failure; an ``HTTPException``
            raised by the helper itself is passed through unchanged.
    """
    try:
        audio_bytes = app.state.audio_utils.speak(text)
        audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
        return AudioResponse(audio_base64=audio_base64)
    except HTTPException:
        # Keep any status code chosen deliberately further down the stack.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in generate_audio: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred"
        ) from e
async def transcribe_audio(
    audio: UploadFile = File(..., description="Audio file to be transcribed.")
):
    """Transcribe an uploaded audio file to text.

    Args:
        audio: The required audio upload; its underlying file object is
            passed to the audio helper.

    Returns:
        TranscriptionResponse: Holds the text returned by the helper.

    Raises:
        HTTPException: 500 for any unexpected failure; an ``HTTPException``
            raised by the helper itself is passed through unchanged.
    """
    try:
        # NOTE(review): the meaning of the positional 0.8 is defined by
        # improved_transcribe, which is not visible here - confirm there.
        audio_transcribe = app.state.audio_utils.improved_transcribe(0.8, audio_file=audio.file)
        return TranscriptionResponse(audio_transcription=audio_transcribe)
    except HTTPException:
        # Keep any status code chosen deliberately further down the stack.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in transcribe_audio: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred"
        ) from e
# Script entry point: serve the app with uvicorn when this file is run
# directly, listening on all interfaces at port 8000 with reload enabled.
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )