File size: 5,626 Bytes
571aca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

import logging
import os
import sys
import threading
from datetime import datetime
from typing import Optional

import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Make the project's ``src`` directory (``../src`` relative to this file's
# directory) importable. The path is made absolute so it does not depend on the
# current working directory when ``__file__`` happens to be relative.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'src'))

from inference.model import AgriQAAssistant

# Setup logging: configure the root logger at INFO; this module logs through
# its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="AgriQA Assistant",
    description="Agricultural assistant chatbot for farming guidance",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Allowed CORS origins, taken from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://a.example,https://b.example").
# When it is unset or blank, every origin is allowed ("*").
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# NOTE(review): FastAPI's CORS guide advises against allow_origins=["*"] together
# with allow_credentials=True; Starlette then reflects the caller's Origin, so any
# site can make credentialed requests. No cookie/auth handling is visible in this
# file; consider setting CORS_ALLOW_ORIGINS explicitly or disabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance: None until the startup hook loads it, and it stays
# None if loading fails.
assistant: Optional[AgriQAAssistant] = None

# Pydantic models for request/response
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # The user's question; between 1 and 1000 characters.
    question: str = Field(..., min_length=1, max_length=1000, description="Agricultural question")
    # Forwarded as ``max_length`` to the assistant's generate_response; defaults
    # to 512 and must lie in [50, 1000] when an integer is given.
    # NOTE(review): ``Optional`` appears to let a client send an explicit null,
    # which bypasses the range check and forwards None to the model; confirm
    # generate_response accepts None, or drop ``Optional``.
    max_length: Optional[int] = Field(512, ge=50, le=1000, description="Maximum response length")

class ChatResponse(BaseModel):
    """Successful response body for ``POST /chat``."""

    # Generated answer text.
    answer: str = Field(..., description="Agricultural guidance response")
    # Generation time in seconds, as reported by the assistant.
    response_time: float = Field(..., description="Response time in seconds")
    # Free-form model metadata; its keys are not validated here.
    # NOTE(review): on Pydantic v2 the "model_" prefix may trigger a
    # protected-namespace warning; confirm against the installed version.
    model_info: dict = Field(..., description="Model information")

class HealthResponse(BaseModel):
    """Response body for ``GET /health``."""

    # Service status string; the /health handler in this file sets it to
    # "healthy" or "unhealthy".
    status: str = Field(..., description="Service status")
    # True when a model instance is available and answered get_model_info().
    model_loaded: bool = Field(..., description="Whether model is loaded")
    # Model metadata when available, otherwise None.
    model_info: Optional[dict] = Field(None, description="Model information")
    # ISO-8601 timestamp string produced by the handler.
    timestamp: str = Field(..., description="Current timestamp")

class ErrorResponse(BaseModel):
    """JSON body used by the global exception handler for unhandled errors."""

    # Client-facing error message.
    error: str = Field(..., description="Error message")
    # ISO-8601 timestamp string of when the error response was built.
    timestamp: str = Field(..., description="Error timestamp")

@app.on_event("startup")
async def startup_event():
    """Load the AgriQA model once when the application starts.

    The model location comes from the ``MODEL_PATH`` environment variable
    (default ``"nada013/agriqa-assistant"``). Loading is best-effort: on failure
    the error is logged with its traceback and ``assistant`` is left as ``None``,
    so the service still starts. ``/chat`` then answers 503 and ``/health``
    reports ``"unhealthy"``.
    """
    global assistant
    model_path = os.getenv("MODEL_PATH", "nada013/agriqa-assistant")
    logger.info("Loading AgriQA assistant model from %s ...", model_path)
    try:
        assistant = AgriQAAssistant(model_path)
    except Exception:
        # Deliberately not re-raised: keep the API up so /health can report it.
        logger.exception("Failed to load model from %s", model_path)
        assistant = None
    else:
        logger.info("Model loaded successfully from %s", model_path)

@app.get("/", response_model=dict)
async def root():
    """Describe the service and point clients at its main endpoints."""
    endpoint_help = {
        "chat": "/chat - Ask agricultural questions",
        "health": "/health - Check API health",
        "docs": "/docs - API documentation",
    }
    info = dict(
        message="AgriQA Assistant",
        version="1.0.0",
        description="Agricultural assistant chatbot for farming guidance",
    )
    info["endpoints"] = endpoint_help
    return info

# Inference runs in worker threads (see chat below). This lock keeps
# generate_response calls one at a time: AgriQAAssistant is not known to be
# thread-safe, and concurrent generations would also multiply memory use.
_GENERATION_LOCK = threading.Lock()


def _generate_response_serialized(model, question, max_length):
    """Call ``model.generate_response`` while holding ``_GENERATION_LOCK``."""
    with _GENERATION_LOCK:
        return model.generate_response(question=question, max_length=max_length)


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer an agricultural question with the loaded model.

    Raises:
        HTTPException: 503 if no model is loaded; 500 if the assistant reports
            an error or generating/packaging the answer fails.
    """
    model = assistant  # read-only access to the module global
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # generate_response is synchronous (it is called without await), so run
        # it in FastAPI's threadpool; otherwise inference would block the event
        # loop and stall every other request, including /health.
        response = await run_in_threadpool(
            _generate_response_serialized, model, request.question, request.max_length
        )

        # The assistant may report failures in-band through an 'error' key.
        if 'error' in response:
            raise HTTPException(status_code=500, detail=response['error'])

        return ChatResponse(
            answer=response['answer'],
            response_time=response['response_time'],
            model_info=response['model_info'],
        )
    except HTTPException:
        # Already carries the intended status and detail; don't let the generic
        # handler below wrap it a second time (which mangled the detail text).
        raise
    except Exception as e:
        logger.exception("Error in chat endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health based on the loaded model.

    Reports ``status="healthy"`` with the model's info when an assistant is
    loaded and its ``get_model_info()`` succeeds. Otherwise reports
    ``status="unhealthy"`` with ``model_loaded=False``. An unhealthy model is
    described in the response body rather than raised as an error; a failing
    ``get_model_info()`` is logged with its traceback.
    """
    model = assistant  # read-only access; no ``global`` statement needed
    model_loaded = model is not None
    model_info = None
    status = "unhealthy"

    if model_loaded:
        try:
            model_info = model.get_model_info()
        except Exception as e:
            logger.exception("Health check failed: %s", e)
            # An assistant that cannot describe itself is reported as not loaded.
            model_loaded = False
        else:
            status = "healthy"

    return HealthResponse(
        status=status,
        model_loaded=model_loaded,
        model_info=model_info,
        timestamp=datetime.now().isoformat(),
    )

@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Turn any otherwise-unhandled exception into a generic 500 JSON response.

    The client only receives ``{"error": "Internal server error", "timestamp": ...}``.
    The exception itself, with its traceback and the request method and path,
    goes to the server log.
    """
    # Pass the exception explicitly so its traceback is logged, not just its message.
    logger.error(
        "Unhandled exception on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,
    )
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal server error",
            timestamp=datetime.now().isoformat()
        ).dict()
    )

# Sample questions served by GET /examples, in display order.
_EXAMPLE_QUESTIONS = (
    "How to control aphid infestation in mustard crops?",
    "What is the best time to plant tomatoes?",
    "How to treat white diarrhoea in poultry?",
    "What fertilizer should I use for coconut plants?",
    "How to preserve potato tubers for 7-8 months?",
    "What are the symptoms of blight disease in potatoes?",
    "How to increase milk production in cows?",
    "What is the recommended spacing for cucumber cultivation?",
    "How to control fruit borer in mango trees?",
    "What are the best practices for organic farming?",
)


# Example usage endpoint
@app.get("/examples", response_model=dict)
async def get_examples():
    """Return a fixed set of sample agricultural questions clients can try."""
    # A fresh list each call, so callers can never mutate the shared constant.
    return {"examples": list(_EXAMPLE_QUESTIONS)}

if __name__ == "__main__":
    # Run the server with uvicorn. With reload=True uvicorn requires an import
    # string rather than the app object, so derive "<module>:app" from this
    # file's name instead of hard-coding "main". reload=True restarts the
    # server on source changes, which suits development rather than production.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        # HOST / PORT environment overrides; defaults are 0.0.0.0:7860.
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
        reload=True,
        log_level="info",
    )