from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import time
from fastapi.middleware.cors import CORSMiddleware
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="YAH Tech AI API",
    description="AI Assistant API with dynamic model loading from HF repo",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class YAHBot:
    def __init__(self):
        self.repo_id = "Adedoyinjames/brain-ai"  # Your HF repo
        self.tokenizer = None
        self.model = None
        self.model_type = None
        self._load_model()

    def _load_model(self):
        """Load the model from Hugging Face repo"""
        try:
            logger.info(f"🚀 Loading AI model from {self.repo_id}...")
            # Load tokenizer and model from your repo
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.repo_id,
                trust_remote_code=True
            )
            # Many causal-LM tokenizers ship without a pad token; reuse EOS so
            # that tokenizing with padding=True below does not raise
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Try to detect model type and load accordingly
            try:
                # First try CausalLM (for models like Mistral, Phi-3, etc.)
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.repo_id,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )
                self.model_type = "causal"
                logger.info("✅ Loaded as CausalLM model")
            except Exception as e:
                logger.warning(f"Failed to load as CausalLM: {e}, trying Seq2Seq...")
                # Fall back to Seq2Seq (for models like T5, etc.)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    self.repo_id,
                    torch_dtype=torch.float16,
                    device_map="auto"
                )
                self.model_type = "seq2seq"
                logger.info("✅ Loaded as Seq2Seq model")

            logger.info("✅ AI model loaded successfully from HF repo!")
        except Exception as e:
            logger.error(f"❌ Failed to load AI model from {self.repo_id}: {e}")
            self.model = None
            self.tokenizer = None
            self.model_type = None

    def _reload_model_if_needed(self):
        """Reload model if it's not loaded (for recovery)"""
        if self.model is None or self.tokenizer is None:
            logger.info("🔄 Attempting to reload model...")
            self._load_model()

    def generate_response(self, user_input):
        """Generate response using AI model"""
        self._reload_model_if_needed()

        if self.model and self.tokenizer:
            try:
                # Format prompt based on model type
                if self.model_type == "causal":
                    # For causal models (Mistral, Phi-3, etc.)
                    prompt = f"<|user|>\n{user_input}\n<|assistant|>\n"
                else:
                    # For seq2seq models (T5, etc.)
                    prompt = f"Question: {user_input}\nAnswer: "

                # Tokenize input
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    max_length=512,
                    truncation=True,
                    padding=True
                )
                # Move to same device as model (this turns the BatchEncoding
                # into a plain dict, so it must be indexed by key below)
                device = next(self.model.parameters()).device
                inputs = {k: v.to(device) for k, v in inputs.items()}

                # Generate response based on model type
                with torch.no_grad():
                    if self.model_type == "causal":
                        outputs = self.model.generate(
                            inputs["input_ids"],
                            attention_mask=inputs["attention_mask"],
                            max_new_tokens=150,
                            num_return_sequences=1,
                            temperature=0.7,
                            do_sample=True,
                            pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.eos_token_id,
                        )
                    else:
                        outputs = self.model.generate(
                            inputs["input_ids"],
                            attention_mask=inputs["attention_mask"],
                            max_length=150,
                            num_return_sequences=1,
                            temperature=0.7,
                            do_sample=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                        )

                # Decode response
                response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                # Clean up response for causal models: drop the echoed prompt
                if self.model_type == "causal":
                    if prompt in response:
                        response = response.replace(prompt, "").strip()

                return response
            except Exception as e:
                logger.error(f"Model generation error: {str(e)}")
                return "I apologize, but I'm having trouble processing your question right now."

        return "AI model is not available. Please check if the model is properly loaded."

    def reload_model(self):
        """Force reload the model from HF repo"""
        logger.info("🔄 Manually reloading model from HF repo...")
        self._load_model()
        return self.model is not None
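
# A minimal smoke test for the class above (hypothetical usage; assumes the
# HF repo is reachable and the model fits on the available hardware):
#
#     bot = YAHBot()
#     print(bot.generate_response("Hello, what can you do?"))
#     bot.reload_model()  # force a fresh load from the repo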

# Initialize the bot globally
yah_bot = YAHBot()


# Request/Response models
class ChatRequest(BaseModel):
    message: str


class ChatResponse(BaseModel):
    response: str
    status: str
    timestamp: float
    model_type: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    service: str
    timestamp: float
    model_loaded: bool
    model_repo: str
    model_type: Optional[str] = None


class ReloadResponse(BaseModel):
    status: str
    message: str
    timestamp: float


# API Endpoints
@app.get("/")
async def root():
    return {
        "message": "YAH Tech AI API is running",
        "status": "active",
        "model_repo": yah_bot.repo_id,
        "model_loaded": yah_bot.model is not None,
        "endpoints": {
            "chat": "POST /api/chat",
            "health": "GET /api/health",
            "reload": "POST /api/reload"
        }
    }


@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Main chat endpoint - Send a message and get AI response
    """
    try:
        response = yah_bot.generate_response(request.message)
        return ChatResponse(
            response=response,
            status="success",
            timestamp=time.time(),
            model_type=yah_bot.model_type
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    return HealthResponse(
        status="healthy",
        service="YAH Tech AI API",
        timestamp=time.time(),
        model_loaded=yah_bot.model is not None,
        model_repo=yah_bot.repo_id,
        model_type=yah_bot.model_type
    )


@app.post("/api/reload", response_model=ReloadResponse)
async def reload_model():
    """
    Manually reload the model from Hugging Face repo.
    Use this after updating your model in the repo.
    """
    try:
        success = yah_bot.reload_model()
        if success:
            return ReloadResponse(
                status="success",
                message="Model reloaded successfully from HF repo",
                timestamp=time.time()
            )
        else:
            return ReloadResponse(
                status="error",
                message="Failed to reload model",
                timestamp=time.time()
            )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error reloading model: {str(e)}")


# For Hugging Face Spaces
def get_app():
    return app


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )
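
# Example client calls once the server is running (host/port match the
# uvicorn.run() settings above; paths are the routes registered earlier):
#
#   curl http://localhost:7860/api/health
#   curl -X POST http://localhost:7860/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!"}'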