Adedoyinjames committed on
Commit
630df96
·
verified ·
1 Parent(s): 98d4d77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -25
app.py CHANGED
@@ -2,14 +2,20 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import uvicorn
4
  import torch
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import time
7
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
8
 
9
  # Initialize FastAPI app
10
  app = FastAPI(
11
  title="YAH Tech AI API",
12
- description="AI Assistant API for testing",
13
  version="1.0.0"
14
  )
15
 
@@ -24,30 +30,76 @@ app.add_middleware(
24
 
25
  class YAHBot:
26
  def __init__(self):
27
- self.model_name = "google/flan-t5-base"
28
  self.tokenizer = None
29
  self.model = None
 
30
  self._load_model()
31
 
32
  def _load_model(self):
33
- """Load the Hugging Face model"""
34
  try:
35
- print("🔄 Loading AI model...")
36
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
37
- self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
38
- print("✅ AI model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
- print(f"❌ Failed to load AI model: {e}")
41
  self.model = None
42
  self.tokenizer = None
 
 
 
 
 
 
 
43
 
44
  def generate_response(self, user_input):
45
  """Generate response using AI model"""
 
 
46
  if self.model and self.tokenizer:
47
  try:
48
- prompt = f"Question: {user_input}\nAnswer: "
 
 
 
 
 
 
49
 
50
- # Tokenize
51
  inputs = self.tokenizer(
52
  prompt,
53
  return_tensors="pt",
@@ -56,25 +108,53 @@ class YAHBot:
56
  padding=True
57
  )
58
 
59
- # Generate response
 
 
 
 
60
  with torch.no_grad():
61
- outputs = self.model.generate(
62
- inputs.input_ids,
63
- max_length=150,
64
- num_return_sequences=1,
65
- temperature=0.7,
66
- do_sample=True,
67
- pad_token_id=self.tokenizer.pad_token_id,
68
- )
 
 
 
 
 
 
 
 
 
 
 
69
 
 
70
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
71
  return response
72
 
73
  except Exception as e:
74
- print(f"Model error: {str(e)}")
75
  return "I apologize, but I'm having trouble processing your question right now."
76
 
77
- return "AI model is not available."
 
 
 
 
 
 
78
 
79
  # Initialize the bot globally
80
  yah_bot = YAHBot()
@@ -87,11 +167,20 @@ class ChatResponse(BaseModel):
87
  response: str
88
  status: str
89
  timestamp: float
 
90
 
91
  class HealthResponse(BaseModel):
92
  status: str
93
  service: str
94
  timestamp: float
 
 
 
 
 
 
 
 
95
 
96
  # API Endpoints
97
  @app.get("/")
@@ -99,9 +188,12 @@ async def root():
99
  return {
100
  "message": "YAH Tech AI API is running",
101
  "status": "active",
 
 
102
  "endpoints": {
103
  "chat": "POST /api/chat",
104
- "health": "GET /api/health"
 
105
  }
106
  }
107
 
@@ -116,7 +208,8 @@ async def chat_endpoint(request: ChatRequest):
116
  return ChatResponse(
117
  response=response,
118
  status="success",
119
- timestamp=time.time()
 
120
  )
121
  except Exception as e:
122
  raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
@@ -126,9 +219,35 @@ async def health_check():
126
  return HealthResponse(
127
  status="healthy",
128
  service="YAH Tech AI API",
129
- timestamp=time.time()
 
 
 
130
  )
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # For Hugging Face Spaces
133
  def get_app():
134
  return app
 
2
  from pydantic import BaseModel
3
  import uvicorn
4
  import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
6
  import time
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ import os
9
+ import logging
10
+
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
 
15
  # Initialize FastAPI app
16
  app = FastAPI(
17
  title="YAH Tech AI API",
18
+ description="AI Assistant API with dynamic model loading from HF repo",
19
  version="1.0.0"
20
  )
21
 
 
30
 
31
  class YAHBot:
32
  def __init__(self):
33
+ self.repo_id = "Adedoyinjames/brain-ai" # Your HF repo
34
  self.tokenizer = None
35
  self.model = None
36
+ self.model_type = None
37
  self._load_model()
38
 
39
  def _load_model(self):
40
+ """Load the model from Hugging Face repo"""
41
  try:
42
+ logger.info(f"🔄 Loading AI model from {self.repo_id}...")
43
+
44
+ # Load tokenizer and model from your repo
45
+ self.tokenizer = AutoTokenizer.from_pretrained(
46
+ self.repo_id,
47
+ trust_remote_code=True
48
+ )
49
+
50
+ # Try to detect model type and load accordingly
51
+ try:
52
+ # First try CausalLM (for models like Mistral, Phi-3, etc.)
53
+ self.model = AutoModelForCausalLM.from_pretrained(
54
+ self.repo_id,
55
+ torch_dtype=torch.float16,
56
+ device_map="auto",
57
+ trust_remote_code=True,
58
+ low_cpu_mem_usage=True
59
+ )
60
+ self.model_type = "causal"
61
+ logger.info("✅ Loaded as CausalLM model")
62
+
63
+ except Exception as e:
64
+ logger.warning(f"Failed to load as CausalLM: {e}, trying Seq2Seq...")
65
+ # Fall back to Seq2Seq (for models like T5, etc.)
66
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
67
+ self.repo_id,
68
+ torch_dtype=torch.float16,
69
+ device_map="auto"
70
+ )
71
+ self.model_type = "seq2seq"
72
+ logger.info("✅ Loaded as Seq2Seq model")
73
+
74
+ logger.info("✅ AI model loaded successfully from HF repo!")
75
+
76
  except Exception as e:
77
+ logger.error(f"❌ Failed to load AI model from {self.repo_id}: {e}")
78
  self.model = None
79
  self.tokenizer = None
80
+ self.model_type = None
81
+
82
+ def _reload_model_if_needed(self):
83
+ """Reload model if it's not loaded (for recovery)"""
84
+ if self.model is None or self.tokenizer is None:
85
+ logger.info("🔄 Attempting to reload model...")
86
+ self._load_model()
87
 
88
  def generate_response(self, user_input):
89
  """Generate response using AI model"""
90
+ self._reload_model_if_needed()
91
+
92
  if self.model and self.tokenizer:
93
  try:
94
+ # Format prompt based on model type
95
+ if self.model_type == "causal":
96
+ # For causal models (Mistral, Phi-3, etc.)
97
+ prompt = f"<|user|>\n{user_input}\n<|assistant|>\n"
98
+ else:
99
+ # For seq2seq models (T5, etc.)
100
+ prompt = f"Question: {user_input}\nAnswer: "
101
 
102
+ # Tokenize input
103
  inputs = self.tokenizer(
104
  prompt,
105
  return_tensors="pt",
 
108
  padding=True
109
  )
110
 
111
+ # Move to same device as model
112
+ device = next(self.model.parameters()).device
113
+ inputs = {k: v.to(device) for k, v in inputs.items()}
114
+
115
+ # Generate response based on model type
116
  with torch.no_grad():
117
+ if self.model_type == "causal":
118
+ outputs = self.model.generate(
119
+ inputs.input_ids,
120
+ max_new_tokens=150,
121
+ num_return_sequences=1,
122
+ temperature=0.7,
123
+ do_sample=True,
124
+ pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
125
+ eos_token_id=self.tokenizer.eos_token_id,
126
+ )
127
+ else:
128
+ outputs = self.model.generate(
129
+ inputs.input_ids,
130
+ max_length=150,
131
+ num_return_sequences=1,
132
+ temperature=0.7,
133
+ do_sample=True,
134
+ pad_token_id=self.tokenizer.pad_token_id,
135
+ )
136
 
137
+ # Decode response
138
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
139
+
140
+ # Clean up response for causal models
141
+ if self.model_type == "causal":
142
+ if prompt in response:
143
+ response = response.replace(prompt, "").strip()
144
+
145
  return response
146
 
147
  except Exception as e:
148
+ logger.error(f"Model generation error: {str(e)}")
149
  return "I apologize, but I'm having trouble processing your question right now."
150
 
151
+ return "AI model is not available. Please check if the model is properly loaded."
152
+
153
+ def reload_model(self):
154
+ """Force reload the model from HF repo"""
155
+ logger.info("🔄 Manually reloading model from HF repo...")
156
+ self._load_model()
157
+ return self.model is not None
158
 
159
  # Initialize the bot globally
160
  yah_bot = YAHBot()
 
167
  response: str
168
  status: str
169
  timestamp: float
170
+ model_type: str = None
171
 
172
  class HealthResponse(BaseModel):
173
  status: str
174
  service: str
175
  timestamp: float
176
+ model_loaded: bool
177
+ model_repo: str
178
+ model_type: str = None
179
+
180
+ class ReloadResponse(BaseModel):
181
+ status: str
182
+ message: str
183
+ timestamp: float
184
 
185
  # API Endpoints
186
  @app.get("/")
 
188
  return {
189
  "message": "YAH Tech AI API is running",
190
  "status": "active",
191
+ "model_repo": yah_bot.repo_id,
192
+ "model_loaded": yah_bot.model is not None,
193
  "endpoints": {
194
  "chat": "POST /api/chat",
195
+ "health": "GET /api/health",
196
+ "reload": "POST /api/reload"
197
  }
198
  }
199
 
 
208
  return ChatResponse(
209
  response=response,
210
  status="success",
211
+ timestamp=time.time(),
212
+ model_type=yah_bot.model_type
213
  )
214
  except Exception as e:
215
  raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
 
219
  return HealthResponse(
220
  status="healthy",
221
  service="YAH Tech AI API",
222
+ timestamp=time.time(),
223
+ model_loaded=yah_bot.model is not None,
224
+ model_repo=yah_bot.repo_id,
225
+ model_type=yah_bot.model_type
226
  )
227
 
228
+ @app.post("/api/reload", response_model=ReloadResponse)
229
+ async def reload_model():
230
+ """
231
+ Manually reload the model from Hugging Face repo
232
+ Use this after updating your model in the repo
233
+ """
234
+ try:
235
+ success = yah_bot.reload_model()
236
+ if success:
237
+ return ReloadResponse(
238
+ status="success",
239
+ message="Model reloaded successfully from HF repo",
240
+ timestamp=time.time()
241
+ )
242
+ else:
243
+ return ReloadResponse(
244
+ status="error",
245
+ message="Failed to reload model",
246
+ timestamp=time.time()
247
+ )
248
+ except Exception as e:
249
+ raise HTTPException(status_code=500, detail=f"Error reloading model: {str(e)}")
250
+
251
  # For Hugging Face Spaces
252
  def get_app():
253
  return app