lahiruchamika27 committed
Commit 3562eea · verified · Parent: 10f5e5a

Update app.py

Files changed (1): app.py (+109, -58)
app.py CHANGED
@@ -2,14 +2,13 @@ import os
 import logging
 import sys
 import torch
+import tempfile
+from pathlib import Path
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
-from typing import List, Dict, Optional, Any
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from typing import List, Optional
 import uvicorn
-import time
 
 # Configure logging
 logging.basicConfig(
@@ -19,13 +18,15 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-app = FastAPI()
+app = FastAPI(title="Chat API", description="Simple chat API for Hugging Face Space")
 
-# Global variables
-model = None
-tokenizer = None
-generator = None
-dataset = None
+# Create a directory for caching in the current working directory
+cache_dir = Path("./model_cache")
+cache_dir.mkdir(exist_ok=True)
+os.environ["TRANSFORMERS_CACHE"] = str(cache_dir.absolute())
+os.environ["HF_HOME"] = str(cache_dir.absolute())
+
+logger.info(f"Using cache directory: {cache_dir.absolute()}")
 
 # Pydantic models for request/response
 class ChatTurn(BaseModel):
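Note on the cache redirect above: TRANSFORMERS_CACHE and HF_HOME only affect downloads that happen after they are set, and this commit also defers the transformers import into try_load_model() (later hunk), so the variables are in place before the library first reads them. A minimal sketch for confirming that downloads land in ./model_cache; the model id "distilgpt2" matches app.py, the rest is illustrative:

    import os
    from pathlib import Path

    cache_dir = Path("./model_cache")
    cache_dir.mkdir(exist_ok=True)
    os.environ["HF_HOME"] = str(cache_dir.absolute())  # set before importing transformers

    from transformers import AutoTokenizer
    AutoTokenizer.from_pretrained("distilgpt2", cache_dir=cache_dir)
    print(sorted(p.name for p in cache_dir.iterdir()))  # cached files appear here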
@@ -39,8 +40,23 @@ class ChatRequest(BaseModel):
 class ChatResponse(BaseModel):
     response: str
 
-# Use a much smaller model suitable for Hugging Face Spaces
-MODEL_ID = "distilgpt2"  # Using a very small model for testing
+# Global variables
+model = None
+tokenizer = None
+generator = None
+dataset = None
+
+# Load a small model or use a fallback if loading fails
+MODEL_ID = "distilgpt2"  # Small model for testing
+
+# Fallback responses for when the model isn't available
+FALLBACK_RESPONSES = [
+    "I apologize, but I'm currently having trouble processing your request.",
+    "Sorry, I'm experiencing technical difficulties at the moment.",
+    "I'm unable to generate a proper response right now. Please try again later.",
+    "My language model is temporarily unavailable. Please check back soon.",
+    "I would like to help, but I'm having some technical issues. Please try again shortly."
+]
 
 # Error handler
 @app.exception_handler(Exception)
@@ -51,58 +67,88 @@ async def generic_exception_handler(request: Request, exc: Exception):
         content={"detail": f"Internal server error: {str(exc)}"}
     )
 
-# Load model on startup
-@app.on_event("startup")
-async def startup_event():
-    global model, tokenizer, generator, dataset
+def try_load_model():
+    """Attempt to load the model and tokenizer with appropriate error handling"""
+    global model, tokenizer, generator
 
     try:
-        logger.info(f"Loading model: {MODEL_ID}")
-        start_time = time.time()
+        # Import here to handle import errors gracefully
+        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
-        # Load the tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-        logger.info(f"Tokenizer loaded in {time.time() - start_time:.2f} seconds")
+        logger.info(f"Loading tokenizer for {MODEL_ID}")
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_ID,
+            cache_dir=cache_dir,
+            local_files_only=False
+        )
+        logger.info("Tokenizer loaded successfully")
 
-        # Load the model with optimizations
+        logger.info(f"Loading model {MODEL_ID}")
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
+            cache_dir=cache_dir,
+            local_files_only=False,
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            low_cpu_mem_usage=True,
-            device_map="auto" if torch.cuda.is_available() else None
+            low_cpu_mem_usage=True
        )
-        logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
+        logger.info("Model loaded successfully")
 
-        # Create a text generation pipeline
         device = 0 if torch.cuda.is_available() else -1
-        generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
-        logger.info(f"Generator pipeline created in {time.time() - start_time:.2f} seconds")
-
-        # Try to load dataset
-        try:
-            logger.info("Loading dataset: lahiruchamika27/tia")
-            dataset = load_dataset("lahiruchamika27/tia")
-            logger.info("Dataset loaded successfully")
-        except Exception as e:
-            logger.error(f"Error loading dataset: {str(e)}")
-            logger.info("Continuing without dataset")
-
-        logger.info(f"Startup completed in {time.time() - start_time:.2f} seconds")
+        logger.info(f"Creating generator pipeline (device: {device})")
+        generator = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            device=device
+        )
+        logger.info("Generator pipeline created successfully")
+        return True
     except Exception as e:
-        logger.error(f"Error during startup: {str(e)}", exc_info=True)
-        logger.info("API will still be available but might not function correctly")
+        logger.error(f"Error loading model: {str(e)}", exc_info=True)
+        return False
 
+def try_load_dataset():
+    """Attempt to load the dataset with appropriate error handling"""
+    global dataset
+
+    try:
+        from datasets import load_dataset
+        logger.info("Loading dataset: lahiruchamika27/tia")
+        dataset = load_dataset("lahiruchamika27/tia", cache_dir=cache_dir)
+        logger.info("Dataset loaded successfully")
+        return True
+    except Exception as e:
+        logger.error(f"Error loading dataset: {str(e)}", exc_info=True)
+        return False
+
+# Startup event
+@app.on_event("startup")
+async def startup_event():
+    logger.info("Starting application")
+    # Try to load model but don't fail if it doesn't work
+    model_loaded = try_load_model()
+    dataset_loaded = try_load_dataset()
+    logger.info(f"Startup complete. Model loaded: {model_loaded}, Dataset loaded: {dataset_loaded}")
+
+# Simple text-only route
+@app.get("/")
+async def root():
+    return {"message": "Chat API is running. Use /api/chat for chat functionality."}
+
+# Chat endpoint
 @app.post("/api/chat", response_model=ChatResponse)
 async def chat(request: ChatRequest):
     logger.info(f"Received chat request: {request.message[:50]}...")
 
-    # Check if model is loaded
+    # If the model isn't loaded, return a fallback response
     if generator is None:
-        logger.error("Text generator not initialized")
-        raise HTTPException(status_code=500, detail="Text generation pipeline not initialized")
+        import random
+        fallback = random.choice(FALLBACK_RESPONSES)
+        logger.warning("Using fallback response because model is not loaded")
+        return ChatResponse(response=fallback)
 
     try:
-        # Format conversation
+        # Format conversation history
         if request.history:
             full_prompt = ""
             for turn in request.history:
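With the fallback branch above, /api/chat always returns an answer whether or not the model loaded. A hedged usage sketch: "message" and the optional "history" follow the ChatRequest model in app.py, the exact ChatTurn fields are not visible in this diff so history is omitted, and the port assumes the default 7860.

    import requests

    resp = requests.post(
        "http://localhost:7860/api/chat",
        json={"message": "Hello, who are you?"},
    )
    resp.raise_for_status()
    print(resp.json()["response"])  # model output, or one of FALLBACK_RESPONSES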
@@ -117,8 +163,7 @@ async def chat(request: ChatRequest):
 
         logger.info(f"Generated prompt: {full_prompt[:100]}...")
 
-        # Generate response
-        start_time = time.time()
+        # Generate text
         outputs = generator(
             full_prompt,
             max_new_tokens=100,
@@ -126,14 +171,13 @@ async def chat(request: ChatRequest):
             top_p=0.9,
             do_sample=True
         )
-        logger.info(f"Text generated in {time.time() - start_time:.2f} seconds")
 
         # Extract response
         generated_text = outputs[0]['generated_text']
-        # Extract only the assistant's response
+        # Extract just the assistant's response
         response_text = generated_text[len(full_prompt):].strip()
 
-        # If empty or just whitespace, return a fallback message
+        # Fallback if response is empty
         if not response_text or response_text.isspace():
             response_text = "I'm sorry, I'm having trouble generating a response right now."
 
@@ -141,8 +185,8 @@ async def chat(request: ChatRequest):
         return ChatResponse(response=response_text)
 
     except Exception as e:
-        logger.error(f"Error generating response: {str(e)}", exc_info=True)
-        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")
+        logger.error(f"Error in chat endpoint: {str(e)}", exc_info=True)
+        return ChatResponse(response="I'm sorry, I encountered an error while processing your request.")
 
 @app.get("/api/examples")
 async def get_examples(count: int = 5, split: str = "train"):
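The generated_text[len(full_prompt):] slice in the hunks above works because a transformers text-generation pipeline returns the prompt and its continuation as a single string under "generated_text". A self-contained sketch using the same distilgpt2 model:

    from transformers import pipeline

    gen = pipeline("text-generation", model="distilgpt2")
    prompt = "Hello there,"
    full = gen(prompt, max_new_tokens=10, do_sample=True)[0]["generated_text"]
    print(full[len(prompt):].strip())  # the new tokens only, as in the chat endpoint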
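Chat errors now surface as a normal ChatResponse instead of an HTTP 500, so clients always see status 200. A hedged test sketch of the related fallback behavior; it assumes app.py imports cleanly as a module named app, and forcing generator to None exercises the fallback branch:

    from fastapi.testclient import TestClient

    import app as app_module

    client = TestClient(app_module.app)
    app_module.generator = None  # force the fallback branch (startup hooks only run inside a `with` block)
    r = client.post("/api/chat", json={"message": "hi"})
    assert r.status_code == 200
    assert r.json()["response"] in app_module.FALLBACK_RESPONSES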
@@ -161,18 +205,25 @@ async def get_examples(count: int = 5, split: str = "train"):
 
 @app.get("/health")
 async def health_check():
-    system_info = {
+    return {
         "status": "ok",
         "model_loaded": model is not None,
         "tokenizer_loaded": tokenizer is not None,
         "generator_loaded": generator is not None,
         "dataset_loaded": dataset is not None,
-        "model_name": MODEL_ID,
-        "torch_device": "cuda" if torch.cuda.is_available() else "cpu",
-        "cuda_available": torch.cuda.is_available(),
-        "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
+        "model_name": MODEL_ID if model is not None else None,
+        "device": "cuda" if torch.cuda.is_available() else "cpu",
+        "cache_dir": str(cache_dir)
+    }
+
+@app.get("/reload")
+async def reload_resources():
+    model_loaded = try_load_model()
+    dataset_loaded = try_load_dataset()
+    return {
+        "model_reloaded": model_loaded,
+        "dataset_reloaded": dataset_loaded
     }
-    return system_info
 
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", 7860))
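Together, the reworked /health payload and the new /reload endpoint let you inspect and retry resource loading without restarting the Space. A quick hedged sketch, again assuming the default port 7860:

    import requests

    base = "http://localhost:7860"
    print(requests.get(f"{base}/health").json())   # load flags, model_name, device, cache_dir
    print(requests.get(f"{base}/reload").json())   # re-runs try_load_model / try_load_dataset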
 