# SehatOnline / app.py
# (Hugging Face Spaces page residue retained as comments:
#  "Noor3's picture" / "Update app.py" / commit "3b36664 verified")
# File: app.py - FastAPI implementation for secure medical chatbot
# Standard library
import hashlib
import hmac
import logging
import os

# Third-party
import torch
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Module-wide logging configuration and logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance.
app = FastAPI(title="Secure Medical Chatbot API")

# CORS: restrict browser access to the approved front-end origin only.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://your-website-domain.com"],  # Replace with your website domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expected API key, taken from the environment.
# NOTE(review): the hard-coded fallback is unsafe outside development —
# always set API_KEY in the deployment environment.
API_KEY = os.environ.get("API_KEY", "your-secret-api-key")  # Set this securely in production
api_key_header = APIKeyHeader(name="X-API-Key")
# Input model for request validation
class QueryInput(BaseModel):
    """Request body for /api/medical-advice: a single free-text user query."""

    query: str
# Create a hash function for privacy
# Create a hash function for privacy
def hash_query(query: str) -> str:
    """Return the hex SHA-256 digest of *query* (logged instead of the raw text)."""
    digest = hashlib.sha256(query.encode())
    return digest.hexdigest()
# Module-level model state; both stay None until load_model() populates them.
model = None
tokenizer = None
# Prefer the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Populate the module-level model/tokenizer globals (no-op if already loaded)."""
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return  # already initialized — nothing to do
    logger.info("Loading model and tokenizer...")
    model_name = "shanover/medbot_godel_v3"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Move weights to the selected device (GPU if available).
    model.to(device)
    logger.info(f"Model loaded on {device}")
# Authentication dependency
async def verify_api_key(api_key: str = Security(api_key_header)):
    """FastAPI dependency: validate the X-API-Key request header.

    Uses hmac.compare_digest for a constant-time comparison, so an attacker
    cannot recover the key byte-by-byte via timing side channels (a plain
    `!=` short-circuits on the first mismatching character).

    Returns:
        The validated API key string.

    Raises:
        HTTPException: 403 when the supplied key does not match.
    """
    if not hmac.compare_digest(api_key, API_KEY):
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
# Generate response function
def generate_response(input_text, max_length=512):
    """Run the seq2seq model on *input_text* and return the decoded reply.

    Args:
        input_text: The user's query string.
        max_length: Token budget applied both to the (truncated) input and
            to the generated output sequence.

    Returns:
        The generated text with special tokens stripped.
    """
    input_ids = tokenizer.encode(
        input_text, return_tensors="pt", max_length=max_length, truncation=True
    )
    input_ids = input_ids.to(device)
    with torch.no_grad():
        # Fix: without an explicit cap, generate() falls back to the model
        # config's default max length (often ~20 tokens) and truncates answers.
        output_ids = model.generate(input_ids, max_length=max_length)
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text
@app.on_event("startup")
async def startup_event():
    """Eagerly warm-load the model at startup so the first request is not slow."""
    # NOTE(review): on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers — behavior kept unchanged here.
    load_model()
@app.post("/api/medical-advice")
async def get_medical_advice(query_input: QueryInput, api_key: str = Depends(verify_api_key)):
    """Generate a chatbot reply for the submitted medical query.

    For privacy, only a SHA-256 hash of the query is logged, never the raw text.

    Returns:
        A dict with the generated "response" and a "status" of "success".

    Raises:
        HTTPException: 500 with a generic message on any processing failure.
    """
    try:
        query = query_input.query
        # Log a hash of the query instead of the query itself for privacy
        logger.info(f"Processing query with hash: {hash_query(query)}")
        response = generate_response(query)
        return {
            "response": response,
            "status": "success"
        }
    except Exception as e:
        # Fix: do not echo str(e) back to the client — that leaks internal
        # details (paths, library internals). Log the full traceback
        # server-side and return a generic message instead.
        logger.exception("Error processing query")
        raise HTTPException(status_code=500, detail="Error processing request") from e
# Health check endpoint
@app.get("/health")
async def health_check():
    """Liveness probe: reports healthy whenever the process is serving."""
    return {"status": "healthy"}
if __name__ == "__main__":
    # Local development entry point; reload=True watches for source changes.
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)