Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import logging | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| from datetime import datetime | |
| # Add src to path for imports | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src')) | |
| from inference.model import AgriQAAssistant | |
# Setup logging: configure the root logger at INFO level once at import time
# and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app (interactive docs served at /docs and /redoc).
app = FastAPI(
    title="AgriQA Assistant",
    description="Agricultural assistant chatbot for farming guidance",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Allowed CORS origins, as a comma-separated list in CORS_ORIGINS. The default
# "*" keeps the previous allow-any-origin behaviour; an empty value also
# falls back to "*".
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # The CORS spec forbids credentialed requests with a wildcard origin, and
    # Starlette's CORSMiddleware requires explicit origins when credentials
    # are allowed. Only enable credentials when an explicit origin list is
    # configured. None of the endpoints in this file read or set cookies.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model instance. It is assigned by startup_event and stays None when
# loading fails, in which case the chat endpoint answers 503 and the health
# endpoint reports "unhealthy".
assistant: Optional["AgriQAAssistant"] = None
# Pydantic models for request/response
class ChatRequest(BaseModel):
    # Request body of the chat endpoint.
    # The question is bounded to 1..1000 characters.
    question: str = Field(..., min_length=1, max_length=1000, description="Agricultural question")
    # Requested answer length, bounded to 50..1000 (default 512).
    # NOTE(review): Optional lets a client send an explicit null, which the chat
    # endpoint forwards to generate_response as max_length=None. Confirm that
    # the model handles None, or drop Optional.
    max_length: Optional[int] = Field(512, ge=50, le=1000, description="Maximum response length")
class ChatResponse(BaseModel):
    # Successful answer returned by the chat endpoint.
    answer: str = Field(..., description="Agricultural guidance response")
    # Generation time as reported by the assistant.
    response_time: float = Field(..., description="Response time in seconds")
    # Model metadata as reported by the assistant.
    # NOTE(review): on some Pydantic v2 releases the "model_" prefix triggers a
    # protected-namespace UserWarning at import time.
    model_info: dict = Field(..., description="Model information")
class HealthResponse(BaseModel):
    # Payload of the health endpoint.
    # Either "healthy" or "unhealthy".
    status: str = Field(..., description="Service status")
    # False when no model is loaded or when querying it failed.
    model_loaded: bool = Field(..., description="Whether model is loaded")
    # Metadata from the assistant; None when unavailable.
    model_info: Optional[dict] = Field(None, description="Model information")
    # ISO-8601 timestamp of the check.
    timestamp: str = Field(..., description="Current timestamp")
class ErrorResponse(BaseModel):
    # Generic error body produced by the global exception handler.
    error: str = Field(..., description="Error message")
    # ISO-8601 timestamp of when the error was reported.
    timestamp: str = Field(..., description="Error timestamp")
@app.on_event("startup")
async def startup_event():
    """Load the AgriQA assistant model when the application starts.

    The model id or path comes from the MODEL_PATH environment variable and
    defaults to "nada013/agriqa-assistant".

    Loading is best-effort. On failure the error is logged with its traceback
    and ``assistant`` stays None, so the API still starts. The chat endpoint
    then answers 503 and the health endpoint reports "unhealthy".
    """
    global assistant
    try:
        logger.info("Loading agriQA assistant model from Hugging Face...")
        model_path = os.getenv("MODEL_PATH", "nada013/agriqa-assistant")
        assistant = AgriQAAssistant(model_path)
        logger.info("Model loaded successfully from Hugging Face")
    except Exception as e:
        # Deliberately not re-raised: keep serving so /health can report it.
        logger.exception("Failed to load model: %s", e)
        assistant = None
@app.get("/")
async def root():
    """Describe the service and point clients at its main endpoints.

    Returns:
        A static dict with the service name, version, description and a
        short map of the available endpoints.
    """
    return {
        "message": "AgriQA Assistant",
        "version": "1.0.0",
        "description": "Agricultural assistant chatbot for farming guidance",
        "endpoints": {
            "chat": "/chat - Ask agricultural questions",
            "health": "/health - Check API health",
            "docs": "/docs - API documentation",
        },
    }
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer an agricultural question using the loaded assistant.

    Args:
        request: Validated body holding the question and the maximum response
            length to pass to the model.

    Returns:
        ChatResponse built from the ``answer``, ``response_time`` and
        ``model_info`` entries of the assistant's response.

    Raises:
        HTTPException: 503 when no model is loaded. 500 when the assistant's
            response contains an ``error`` entry, or when generation or
            building the response fails.
    """
    if assistant is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # NOTE(review): generate_response is called synchronously inside an
        # async handler, so a slow generation presumably blocks the event loop
        # (including /health). Consider offloading it to a worker thread.
        response = assistant.generate_response(
            question=request.question,
            max_length=request.max_length,
        )

        # A response carrying an 'error' entry is treated as a failed generation.
        if 'error' in response:
            raise HTTPException(status_code=500, detail=response['error'])

        return ChatResponse(
            answer=response['answer'],
            response_time=response['response_time'],
            model_info=response['model_info'],
        )
    except HTTPException:
        # Propagate deliberate HTTP errors unchanged. The generic branch below
        # would otherwise re-wrap them and corrupt the detail, because str()
        # of an HTTPException is "<status>: <detail>".
        raise
    except Exception as e:
        logger.exception("Error in chat endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded and able to describe itself.

    The service is "healthy" only when a model is loaded and
    ``get_model_info()`` succeeds. If that call raises, the error is logged
    and the model is reported as not loaded.

    Returns:
        HealthResponse with the status, the model-loaded flag, the model info
        (None when unavailable) and the current local-time ISO timestamp.
    """
    model_loaded = assistant is not None
    model_info = None
    status = "unhealthy"

    if model_loaded:
        try:
            model_info = assistant.get_model_info()
            status = "healthy"
        except Exception as e:
            logger.exception("Health check failed: %s", e)
            model_loaded = False

    return HealthResponse(
        status=status,
        model_loaded=model_loaded,
        model_info=model_info,
        timestamp=datetime.now().isoformat(),
    )
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Convert any unhandled exception into a generic JSON 500 response.

    Clients only see a fixed "Internal server error" message plus a
    timestamp. The exception itself, including its traceback, is logged
    server-side.

    Args:
        request: The request being processed (unused).
        exc: The unhandled exception.

    Returns:
        JSONResponse with status 500 and an ErrorResponse body.
    """
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal server error",
            timestamp=datetime.now().isoformat(),
        ).dict(),
    )
# Example usage endpoint
# NOTE(review): the route path is not visible in the source; "/examples" is
# assumed. Confirm it matches what clients call.
@app.get("/examples")
async def get_examples():
    """Get example agricultural questions.

    Returns:
        A dict whose "examples" key holds a fixed list of sample questions
        that can be sent to the chat endpoint.
    """
    return {
        "examples": [
            "How to control aphid infestation in mustard crops?",
            "What is the best time to plant tomatoes?",
            "How to treat white diarrhoea in poultry?",
            "What fertilizer should I use for coconut plants?",
            "How to preserve potato tubers for 7-8 months?",
            "What are the symptoms of blight disease in potatoes?",
            "How to increase milk production in cows?",
            "What is the recommended spacing for cucumber cultivation?",
            "How to control fruit borer in mango trees?",
            "What are the best practices for organic farming?",
        ]
    }
if __name__ == "__main__":
    # Run the server directly (e.g. `python main.py`). It listens on
    # 0.0.0.0:7860 by default; PORT overrides the port.
    # Auto-reload is a development feature, so it is opt-in via RELOAD=true.
    reload = os.getenv("RELOAD", "false").strip().lower() in ("1", "true", "yes")
    uvicorn.run(
        # Reload mode needs an import string (which assumes this module is
        # importable as "main"). Otherwise pass the already-built app object,
        # so the module is not imported a second time under another name.
        "main:app" if reload else app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
        reload=reload,
        log_level="info",
    )