"""FastAPI service exposing the AgriQA agricultural assistant over HTTP.

Endpoints:
    GET  /          - service description
    POST /chat      - ask an agricultural question
    GET  /health    - service/model health
    GET  /examples  - sample questions
"""

import os
import sys
import logging
from typing import Optional
from datetime import datetime

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import uvicorn

# Add src to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))

from inference.model import AgriQAAssistant

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="AgriQA Assistant",
    description="Agricultural assistant chatbot for farming guidance",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the CORS spec and the FastAPI docs; left as-is to keep the
# deployed behavior unchanged -- confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance; populated by startup_event, stays None if loading fails.
assistant: Optional[AgriQAAssistant] = None


# Pydantic models for request/response
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    question: str = Field(..., min_length=1, max_length=1000, description="Agricultural question")
    max_length: Optional[int] = Field(512, ge=50, le=1000, description="Maximum response length")


class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    answer: str = Field(..., description="Agricultural guidance response")
    response_time: float = Field(..., description="Response time in seconds")
    model_info: dict = Field(..., description="Model information")


class HealthResponse(BaseModel):
    """Response body for GET /health."""

    status: str = Field(..., description="Service status")
    model_loaded: bool = Field(..., description="Whether model is loaded")
    model_info: Optional[dict] = Field(None, description="Model information")
    timestamp: str = Field(..., description="Current timestamp")


class ErrorResponse(BaseModel):
    """Body returned by the global exception handler."""

    error: str = Field(..., description="Error message")
    timestamp: str = Field(..., description="Error timestamp")


@app.on_event("startup")
async def startup_event() -> None:
    """Load the assistant model once at application startup.

    The model identifier is read from the MODEL_PATH environment variable,
    falling back to "nada013/agriqa-assistant". A load failure is logged and
    leaves ``assistant`` as None so the app still starts; /chat then answers
    503 and /health reports "unhealthy".
    """
    global assistant
    try:
        logger.info("Loading agriQA assistant model from Hugging Face...")
        model_path = os.getenv("MODEL_PATH", "nada013/agriqa-assistant")
        assistant = AgriQAAssistant(model_path)
        logger.info("Model loaded successfully from Hugging Face")
    except Exception as e:
        # Deliberately best-effort: keep serving without a model.
        # logger.exception keeps the same message and adds the traceback.
        logger.exception("Failed to load model: %s", e)
        assistant = None


@app.get("/", response_model=dict)
async def root() -> dict:
    """Return a static description of the service and its endpoints."""
    return {
        "message": "AgriQA Assistant",
        "version": "1.0.0",
        "description": "Agricultural assistant chatbot for farming guidance",
        "endpoints": {
            "chat": "/chat - Ask agricultural questions",
            "health": "/health - Check API health",
            "docs": "/docs - API documentation"
        }
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
    """Answer an agricultural question with the loaded assistant.

    Args:
        request: Validated question and maximum response length.

    Returns:
        The answer, response time and model info taken from the assistant's
        response mapping.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if the assistant
            reports an error or generation raises.
    """
    if assistant is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # NOTE(review): this call runs inline in an async handler; if
        # generate_response blocks (likely for model inference), it stalls
        # the event loop for other requests -- confirm and consider a
        # threadpool.
        response = assistant.generate_response(
            question=request.question,
            max_length=request.max_length
        )

        # The assistant signals failure through an 'error' key.
        if 'error' in response:
            logger.error("Error in chat endpoint: %s", response['error'])
            raise HTTPException(status_code=500, detail=response['error'])

        return ChatResponse(
            answer=response['answer'],
            response_time=response['response_time'],
            model_info=response['model_info']
        )
    except HTTPException:
        # Let deliberate HTTP errors through untouched; previously the broad
        # handler below re-wrapped them and mangled the detail message.
        raise
    except Exception as e:
        logger.exception("Error in chat endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Report whether the model is loaded and can describe itself.

    Status is "healthy" only when the assistant exists and
    ``get_model_info()`` succeeds; otherwise "unhealthy" with
    ``model_loaded`` set to False.
    """
    model_loaded = assistant is not None
    model_info = None
    status = "unhealthy"

    if model_loaded:
        try:
            model_info = assistant.get_model_info()
            status = "healthy"
        except Exception as e:
            logger.exception("Health check failed: %s", e)
            model_loaded = False

    return HealthResponse(
        status=status,
        model_loaded=model_loaded,
        model_info=model_info,
        timestamp=datetime.now().isoformat()
    )


@app.exception_handler(Exception)
async def global_exception_handler(request, exc) -> JSONResponse:
    """Convert any otherwise-unhandled exception into a generic JSON 500.

    The exception text is logged but not sent to the client.
    """
    # exc_info=exc attaches the traceback even though we are not inside an
    # ``except`` block here.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal server error",
            timestamp=datetime.now().isoformat()
        ).dict()
    )


# Example usage endpoint
@app.get("/examples", response_model=dict)
async def get_examples() -> dict:
    """Get example agricultural questions."""
    return {
        "examples": [
            "How to control aphid infestation in mustard crops?",
            "What is the best time to plant tomatoes?",
            "How to treat white diarrhoea in poultry?",
            "What fertilizer should I use for coconut plants?",
            "How to preserve potato tubers for 7-8 months?",
            "What are the symptoms of blight disease in potatoes?",
            "How to increase milk production in cows?",
            "What is the recommended spacing for cucumber cultivation?",
            "How to control fruit borer in mango trees?",
            "What are the best practices for organic farming?"
        ]
    }


if __name__ == "__main__":
    # Run the server
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
        log_level="info"
    )