# Upload metadata: othdu — "Upload 18 files" (commit 571aca4, verified).
import os
import sys
import logging
from typing import Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import uvicorn
from datetime import datetime
# Add src to path for imports: make ../src (relative to this file's directory)
# importable so that `inference.model` below resolves. This must run before
# that import.
# NOTE(review): append (not insert(0)) means an already-installed top-level
# package named `inference` would shadow src/inference — confirm intended.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
from inference.model import AgriQAAssistant
# Setup logging: root logger at INFO; `logger` is this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application; interactive docs are served at /docs and /redoc.
app = FastAPI(
    title="AgriQA Assistant",
    version="1.0.0",
    description="Agricultural assistant chatbot for farming guidance",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Permissive CORS: every origin, method and header is accepted.
# NOTE(review): with a wildcard origin plus allow_credentials=True, Starlette
# reflects the caller's Origin back, so any site may send credentialed
# requests. No endpoint in this module reads cookies or auth headers —
# consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Shared model instance. The startup hook assigns it; it stays None when
# loading fails, and the endpoints treat None as "model not loaded".
assistant: "Optional[AgriQAAssistant]" = None
# Pydantic models for request/response


# Request body accepted by POST /chat.
class ChatRequest(BaseModel):
    # Required; between 1 and 1000 characters.
    question: str = Field(
        default=...,
        description="Agricultural question",
        min_length=1,
        max_length=1000,
    )
    # Defaults to 512; when an int is supplied it must lie in [50, 1000].
    # NOTE(review): Optional[int] also accepts an explicit JSON null, which
    # skips the ge/le bounds and is forwarded as None to
    # assistant.generate_response — confirm the assistant handles None, or
    # tighten the annotation to int.
    max_length: Optional[int] = Field(
        default=512,
        description="Maximum response length",
        ge=50,
        le=1000,
    )
# Successful response body returned by POST /chat.
# NOTE(review): under pydantic v2, field names starting with "model_" trigger a
# protected-namespace warning — confirm the installed pydantic version.
class ChatResponse(BaseModel):
    answer: str = Field(default=..., description="Agricultural guidance response")
    response_time: float = Field(default=..., description="Response time in seconds")
    model_info: dict = Field(default=..., description="Model information")
# Response body returned by GET /health.
# NOTE(review): under pydantic v2, "model_loaded"/"model_info" trigger a
# protected-namespace warning — confirm the installed pydantic version.
class HealthResponse(BaseModel):
    status: str = Field(default=..., description="Service status")
    model_loaded: bool = Field(default=..., description="Whether model is loaded")
    # Optional: left as None whenever model details could not be fetched.
    model_info: Optional[dict] = Field(default=None, description="Model information")
    timestamp: str = Field(default=..., description="Current timestamp")
# Generic error body produced by the global exception handler.
class ErrorResponse(BaseModel):
    error: str = Field(default=..., description="Error message")
    timestamp: str = Field(default=..., description="Error timestamp")
@app.on_event("startup")
async def startup_event():
    """Load the AgriQA model once when the server starts.

    The model identifier is read from the ``MODEL_PATH`` environment variable
    (default ``"nada013/agriqa-assistant"``) and passed to ``AgriQAAssistant``.
    Loading is best-effort: on failure the error is logged with its traceback
    and the module-level ``assistant`` is left as ``None``. ``/chat`` then
    answers 503 and ``/health`` reports "unhealthy", instead of the server
    failing to start.
    """
    # NOTE(review): on_event is deprecated in recent FastAPI releases in favour
    # of lifespan handlers; kept here to avoid changing app construction.
    global assistant
    try:
        logger.info("Loading agriQA assistant model from Hugging Face...")
        model_path = os.getenv("MODEL_PATH", "nada013/agriqa-assistant")
        assistant = AgriQAAssistant(model_path)
        logger.info("Model loaded successfully from Hugging Face")
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped;
        # without it a failed model load is very hard to diagnose.
        logger.exception(f"Failed to load model: {e}")
        assistant = None
@app.get("/", response_model=dict)
async def root():
    """Describe the service and list its main endpoints."""
    endpoint_help = dict(
        chat="/chat - Ask agricultural questions",
        health="/health - Check API health",
        docs="/docs - API documentation",
    )
    return dict(
        message="AgriQA Assistant",
        version="1.0.0",
        description="Agricultural assistant chatbot for farming guidance",
        endpoints=endpoint_help,
    )
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer an agricultural question.

    Returns 503 if the model is not loaded and 500 if generation fails.
    """
    if assistant is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # NOTE(review): generate_response runs synchronously inside this async
        # handler, so it blocks the event loop for the whole generation.
        response = assistant.generate_response(
            question=request.question,
            max_length=request.max_length,
        )
        # The assistant may report a failure through an 'error' key instead
        # of raising.
        if 'error' in response:
            logger.error(f"Error in chat endpoint: {response['error']}")
            raise HTTPException(status_code=500, detail=response['error'])
        return ChatResponse(
            answer=response['answer'],
            response_time=response['response_time'],
            model_info=response['model_info'],
        )
    except HTTPException:
        # Already carries the intended status and detail. Previously the
        # generic handler below re-wrapped it, turning the detail into
        # "500: <message>".
        raise
    except Exception as e:
        # Missing keys, response validation failures, or errors raised by
        # the model itself.
        logger.exception(f"Error in chat endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health and whether the model is loaded."""
    # Pessimistic defaults. The service is only "healthy" when the model
    # exists AND get_model_info() succeeds; a failure there also reports the
    # model as not loaded.
    info = None
    loaded = False
    state = "unhealthy"
    if assistant is not None:
        try:
            info = assistant.get_model_info()
        except Exception as err:
            logger.error(f"Health check failed: {err}")
        else:
            loaded = True
            state = "healthy"
    return HealthResponse(
        status=state,
        model_loaded=loaded,
        model_info=info,
        timestamp=datetime.now().isoformat(),
    )
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Convert any otherwise-unhandled exception into a generic JSON 500.

    The exception is logged server-side with its traceback. The client only
    receives the fixed message "Internal server error" plus a timestamp, so
    internal details are not exposed.
    """
    # exc_info=exc attaches the traceback explicitly instead of relying on an
    # active ``except`` block; the previous logger.error call dropped it.
    logger.error(f"Unhandled exception: {exc}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        # NOTE(review): .dict() is the pydantic v1 API (deprecated in v2 in
        # favour of .model_dump()) — confirm the installed version.
        content=ErrorResponse(
            error="Internal server error",
            timestamp=datetime.now().isoformat(),
        ).dict(),
    )
# Example usage endpoint
@app.get("/examples", response_model=dict)
async def get_examples():
    """Get example agricultural questions."""
    # Fixed sample prompts, covering crops, livestock and general practice,
    # that clients can offer as suggestions.
    sample_questions = (
        "How to control aphid infestation in mustard crops?",
        "What is the best time to plant tomatoes?",
        "How to treat white diarrhoea in poultry?",
        "What fertilizer should I use for coconut plants?",
        "How to preserve potato tubers for 7-8 months?",
        "What are the symptoms of blight disease in potatoes?",
        "How to increase milk production in cows?",
        "What is the recommended spacing for cucumber cultivation?",
        "How to control fruit borer in mango trees?",
        "What are the best practices for organic farming?",
    )
    return {"examples": list(sample_questions)}
if __name__ == "__main__":
    # Run the server directly. The defaults are unchanged (all interfaces,
    # port 7860, auto-reload on). They can be overridden without editing the
    # file via the PORT environment variable (an integer) and RELOAD
    # ("1"/"true"/"yes" enable it; any other value disables it).
    # "main:app" requires this module to be importable as `main`.
    port = int(os.getenv("PORT", "7860"))
    reload = os.getenv("RELOAD", "true").strip().lower() in ("1", "true", "yes")
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=port,
        reload=reload,
        log_level="info",
    )