"""YAHbot FastAPI application (``app.py`` for a Hugging Face Space).

Serves an AI assistant API backed by a model loaded dynamically from a
Hugging Face model repository.
"""
import logging
import os
import time
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
title="YAH Tech AI API",
description="AI Assistant API with dynamic model loading from HF repo",
version="1.0.0"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class YAHBot:
    """AI chat assistant backed by a Hugging Face model repository.

    Loads a tokenizer/model pair from ``self.repo_id`` at construction time,
    auto-detecting whether the checkpoint is a causal LM (Mistral/Phi-style)
    or a seq2seq model (T5-style), and exposes ``generate_response`` for
    single-turn question answering. A failed load leaves the instance in a
    degraded-but-alive state (``model is None``) instead of raising.
    """

    def __init__(self):
        self.repo_id = "Adedoyinjames/brain-ai"  # HF repo the weights come from
        self.tokenizer = None
        self.model = None
        self.model_type = None  # "causal" | "seq2seq" | None (not loaded)
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model from the HF repo.

        On any failure, logs the error and resets ``tokenizer``/``model``/
        ``model_type`` to ``None`` so the API keeps serving (deliberate
        best-effort; callers check ``model is not None``).
        """
        try:
            logger.info(f"🔄 Loading AI model from {self.repo_id}...")
            # Load tokenizer and model from the configured repo
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.repo_id,
                trust_remote_code=True
            )
            # Many causal checkpoints ship without a pad token; without this,
            # tokenizing with padding=True below raises at request time.
            if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Try to detect model type and load accordingly
            try:
                # First try CausalLM (for models like Mistral, Phi-3, etc.)
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.repo_id,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )
                self.model_type = "causal"
                logger.info("✅ Loaded as CausalLM model")
            except Exception as e:
                logger.warning(f"Failed to load as CausalLM: {e}, trying Seq2Seq...")
                # Fall back to Seq2Seq (for models like T5, etc.)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    self.repo_id,
                    torch_dtype=torch.float16,
                    device_map="auto"
                )
                self.model_type = "seq2seq"
                logger.info("✅ Loaded as Seq2Seq model")
            logger.info("✅ AI model loaded successfully from HF repo!")
        except Exception as e:
            logger.error(f"❌ Failed to load AI model from {self.repo_id}: {e}")
            self.model = None
            self.tokenizer = None
            self.model_type = None

    def _reload_model_if_needed(self):
        """Reload the model if a previous load failed (lazy recovery)."""
        if self.model is None or self.tokenizer is None:
            logger.info("🔄 Attempting to reload model...")
            self._load_model()

    def generate_response(self, user_input):
        """Generate a reply to ``user_input`` using the loaded model.

        Returns the decoded model output, or a human-readable fallback string
        when the model is unavailable or generation fails. Never raises.
        """
        self._reload_model_if_needed()
        if not (self.model and self.tokenizer):
            return "AI model is not available. Please check if the model is properly loaded."
        try:
            # Format prompt based on model type
            if self.model_type == "causal":
                prompt = f"<|user|>\n{user_input}\n<|assistant|>\n"
            else:
                prompt = f"Question: {user_input}\nAnswer: "
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )
            # BUG FIX: the original rebuilt the BatchEncoding as a plain dict
            # and then accessed ``inputs.input_ids`` — an AttributeError on
            # every request. Move the individual tensors to the model device
            # and index them explicitly instead.
            device = next(self.model.parameters()).device
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)
            with torch.no_grad():
                if self.model_type == "causal":
                    outputs = self.model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=150,
                        num_return_sequences=1,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                    )
                    # Decode only the newly generated tokens so the prompt
                    # can never leak into the reply (more robust than the
                    # original's string replace of the prompt text).
                    response = self.tokenizer.decode(
                        outputs[0][input_ids.shape[-1]:],
                        skip_special_tokens=True,
                    ).strip()
                else:
                    outputs = self.model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=150,
                        num_return_sequences=1,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                    )
                    # Seq2seq output contains only the answer; decode as-is.
                    response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response
        except Exception as e:
            logger.error(f"Model generation error: {str(e)}")
            return "I apologize, but I'm having trouble processing your question right now."

    def reload_model(self):
        """Force-reload the model from the HF repo.

        Returns True when a model is live afterwards, False otherwise.
        """
        logger.info("🔄 Manually reloading model from HF repo...")
        self._load_model()
        return self.model is not None
# Initialize the bot globally
yah_bot = YAHBot()
# Request/Response models
class ChatRequest(BaseModel):
message: str
class ChatResponse(BaseModel):
    """Response payload for POST /api/chat."""

    response: str    # generated reply or fallback message
    status: str      # "success" on the happy path
    timestamp: float  # server time (time.time())
    # FIX: was ``str = None`` — Pydantic v2 rejects an explicit None for a
    # non-optional str field, so responses with no loaded model would fail
    # validation. Optional[str] is backward compatible (v1 and v2).
    model_type: Optional[str] = None
class HealthResponse(BaseModel):
    """Response payload for GET /api/health."""

    status: str        # always "healthy" when the endpoint answers
    service: str
    timestamp: float   # server time (time.time())
    model_loaded: bool
    model_repo: str
    # FIX: was ``str = None`` — Pydantic v2 rejects an explicit None for a
    # non-optional str field, so health checks with no loaded model would
    # fail validation. Optional[str] is backward compatible (v1 and v2).
    # NOTE(review): ``model_*`` field names trigger Pydantic v2 protected-
    # namespace warnings; renaming would break the public API, so left as-is.
    model_type: Optional[str] = None
class ReloadResponse(BaseModel):
status: str
message: str
timestamp: float
# API Endpoints
@app.get("/")
async def root():
    """Service banner: status, configured model repo, and endpoint map."""
    endpoint_map = {
        "chat": "POST /api/chat",
        "health": "GET /api/health",
        "reload": "POST /api/reload",
    }
    return {
        "message": "YAH Tech AI API is running",
        "status": "active",
        "model_repo": yah_bot.repo_id,
        "model_loaded": yah_bot.model is not None,
        "endpoints": endpoint_map,
    }
@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Main chat endpoint - Send a message and get AI response
    """
    try:
        reply = yah_bot.generate_response(request.message)
        payload = ChatResponse(
            response=reply,
            status="success",
            timestamp=time.time(),
            model_type=yah_bot.model_type,
        )
    except Exception as e:
        # Surface any unexpected failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
    return payload
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Liveness probe reporting service identity and model-load state."""
    snapshot = {
        "status": "healthy",
        "service": "YAH Tech AI API",
        "timestamp": time.time(),
        "model_loaded": yah_bot.model is not None,
        "model_repo": yah_bot.repo_id,
        "model_type": yah_bot.model_type,
    }
    return HealthResponse(**snapshot)
@app.post("/api/reload", response_model=ReloadResponse)
async def reload_model():
    """
    Manually reload the model from Hugging Face repo
    Use this after updating your model in the repo
    """
    try:
        reloaded = yah_bot.reload_model()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error reloading model: {str(e)}")
    if reloaded:
        status, message = "success", "Model reloaded successfully from HF repo"
    else:
        status, message = "error", "Failed to reload model"
    return ReloadResponse(
        status=status,
        message=message,
        timestamp=time.time(),
    )
# For Hugging Face Spaces
def get_app():
    """Return the module-level FastAPI instance (Spaces entry point)."""
    return app
if __name__ == "__main__":
uvicorn.run(
app,
host="0.0.0.0",
port=7860,
log_level="info"
)