lahiruchamika27 committed
Commit 7e844c6 · verified · 1 Parent(s): 903fc24

Update app.py

Files changed (1)
  1. app.py +25 -157
app.py CHANGED
@@ -3,12 +3,12 @@ import logging
  import sys
  import torch
  import tempfile
- from pathlib import Path
  from fastapi import FastAPI, HTTPException, Request
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel
  from typing import List, Optional
  import uvicorn
+ import random

  # Configure logging
  logging.basicConfig(
@@ -20,13 +20,12 @@ logger = logging.getLogger(__name__)

  app = FastAPI(title="Chat API", description="Simple chat API for Hugging Face Space")

- # Create a directory for caching in the current working directory
- cache_dir = Path("./model_cache")
- cache_dir.mkdir(exist_ok=True)
- os.environ["TRANSFORMERS_CACHE"] = str(cache_dir.absolute())
- os.environ["HF_HOME"] = str(cache_dir.absolute())
+ # Use the system's temporary directory which should be writable
+ temp_dir = tempfile.mkdtemp()
+ os.environ["TRANSFORMERS_CACHE"] = temp_dir
+ os.environ["HF_HOME"] = temp_dir

- logger.info(f"Using cache directory: {cache_dir.absolute()}")
+ logger.info(f"Using temporary directory: {temp_dir}")

  # Pydantic models for request/response
  class ChatTurn(BaseModel):
@@ -40,16 +39,7 @@ class ChatRequest(BaseModel):
  class ChatResponse(BaseModel):
      response: str

- # Global variables
- model = None
- tokenizer = None
- generator = None
- dataset = None
-
- # Load a small model or use a fallback if loading fails
- MODEL_ID = "distilgpt2"  # Small model for testing
-
- # Fallback responses for when the model isn't available
+ # Fallback responses
  FALLBACK_RESPONSES = [
      "I apologize, but I'm currently having trouble processing your request.",
      "Sorry, I'm experiencing technical difficulties at the moment.",
@@ -67,162 +57,40 @@ async def generic_exception_handler(request: Request, exc: Exception):
          content={"detail": f"Internal server error: {str(exc)}"}
      )

- def try_load_model():
-     """Attempt to load the model and tokenizer with appropriate error handling"""
-     global model, tokenizer, generator
-
-     try:
-         # Import here to handle import errors gracefully
-         from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
-         logger.info(f"Loading tokenizer for {MODEL_ID}")
-         tokenizer = AutoTokenizer.from_pretrained(
-             MODEL_ID,
-             cache_dir=cache_dir,
-             local_files_only=False
-         )
-         logger.info("Tokenizer loaded successfully")
-
-         logger.info(f"Loading model {MODEL_ID}")
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID,
-             cache_dir=cache_dir,
-             local_files_only=False,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             low_cpu_mem_usage=True
-         )
-         logger.info("Model loaded successfully")
-
-         device = 0 if torch.cuda.is_available() else -1
-         logger.info(f"Creating generator pipeline (device: {device})")
-         generator = pipeline(
-             "text-generation",
-             model=model,
-             tokenizer=tokenizer,
-             device=device
-         )
-         logger.info("Generator pipeline created successfully")
-         return True
-     except Exception as e:
-         logger.error(f"Error loading model: {str(e)}", exc_info=True)
-         return False
-
- def try_load_dataset():
-     """Attempt to load the dataset with appropriate error handling"""
-     global dataset
-
-     try:
-         from datasets import load_dataset
-         logger.info("Loading dataset: lahiruchamika27/tia")
-         dataset = load_dataset("lahiruchamika27/tia", cache_dir=cache_dir)
-         logger.info("Dataset loaded successfully")
-         return True
-     except Exception as e:
-         logger.error(f"Error loading dataset: {str(e)}", exc_info=True)
-         return False
-
- # Startup event
- @app.on_event("startup")
- async def startup_event():
-     logger.info("Starting application")
-     # Try to load model but don't fail if it doesn't work
-     model_loaded = try_load_model()
-     dataset_loaded = try_load_dataset()
-     logger.info(f"Startup complete. Model loaded: {model_loaded}, Dataset loaded: {dataset_loaded}")
-
  # Simple text-only route
  @app.get("/")
  async def root():
      return {"message": "Chat API is running. Use /api/chat for chat functionality."}

- # Chat endpoint
+ # Chat endpoint - just use fallback responses for now
  @app.post("/api/chat", response_model=ChatResponse)
  async def chat(request: ChatRequest):
      logger.info(f"Received chat request: {request.message[:50]}...")

-     # If the model isn't loaded, return a fallback response
-     if generator is None:
-         import random
-         fallback = random.choice(FALLBACK_RESPONSES)
-         logger.warning("Using fallback response because model is not loaded")
-         return ChatResponse(response=fallback)
-
-     try:
-         # Format conversation history
-         if request.history:
-             full_prompt = ""
-             for turn in request.history:
-                 if turn.user:
-                     full_prompt += f"User: {turn.user}\n"
-                 if turn.assistant:
-                     full_prompt += f"Assistant: {turn.assistant}\n"
-
-             full_prompt += f"User: {request.message}\nAssistant:"
-         else:
-             full_prompt = f"User: {request.message}\nAssistant:"
-
-         logger.info(f"Generated prompt: {full_prompt[:100]}...")
-
-         # Generate text
-         outputs = generator(
-             full_prompt,
-             max_new_tokens=100,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True
-         )
-
-         # Extract response
-         generated_text = outputs[0]['generated_text']
-         # Extract just the assistant's response
-         response_text = generated_text[len(full_prompt):].strip()
-
-         # Fallback if response is empty
-         if not response_text or response_text.isspace():
-             response_text = "I'm sorry, I'm having trouble generating a response right now."
-
-         logger.info(f"Final response: {response_text[:50]}...")
-         return ChatResponse(response=response_text)
-
-     except Exception as e:
-         logger.error(f"Error in chat endpoint: {str(e)}", exc_info=True)
-         return ChatResponse(response="I'm sorry, I encountered an error while processing your request.")
-
- @app.get("/api/examples")
- async def get_examples(count: int = 5, split: str = "train"):
-     if dataset is None:
-         raise HTTPException(status_code=500, detail="Dataset not loaded")
-
-     try:
-         if split in dataset:
-             # Convert dataset items to dict for easier JSON serialization
-             examples = [dict(item) for item in dataset[split][:count]]
-             return {"examples": examples}
-         else:
-             raise HTTPException(status_code=400, detail=f"Split '{split}' not found in dataset")
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+     # Select a random fallback response
+     fallback = random.choice(FALLBACK_RESPONSES)
+
+     # Add a bit of personalization
+     if "hello" in request.message.lower() or "hi" in request.message.lower():
+         fallback = "Hello! " + fallback
+     elif "help" in request.message.lower():
+         fallback = "I'd like to help you with that, but " + fallback.lower()
+
+     logger.info("Returning fallback response")
+     return ChatResponse(response=fallback)

  @app.get("/health")
  async def health_check():
      return {
          "status": "ok",
-         "model_loaded": model is not None,
-         "tokenizer_loaded": tokenizer is not None,
-         "generator_loaded": generator is not None,
-         "dataset_loaded": dataset is not None,
-         "model_name": MODEL_ID if model is not None else None,
-         "device": "cuda" if torch.cuda.is_available() else "cpu",
-         "cache_dir": str(cache_dir)
-     }
-
- @app.get("/reload")
- async def reload_resources():
-     model_loaded = try_load_model()
-     dataset_loaded = try_load_dataset()
-     return {
-         "model_reloaded": model_loaded,
-         "dataset_reloaded": dataset_loaded
+         "system_info": {
+             "device": "cpu",  # No GPU for now
+             "temp_dir": temp_dir,
+             "pwd": os.getcwd(),
+             "user": os.getenv("USER", "unknown"),
+             "writable_temp": os.access(temp_dir, os.W_OK),
+             "writable_cwd": os.access(os.getcwd(), os.W_OK)
+         }
      }

  if __name__ == "__main__":
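
For reference, a minimal client sketch against the two endpoints this commit leaves in place (/api/chat and /health). The base URL and port are assumptions (7860 is the usual Spaces port, not confirmed by the commit), as is the use of the requests library; the payload shape follows the ChatRequest/ChatTurn models in app.py, and history is assumed optional since the fallback-only endpoint never reads it.

# Minimal client sketch; BASE_URL and the requests dependency are assumptions.
import requests

BASE_URL = "http://localhost:7860"  # hypothetical; point at wherever the Space is served

# /api/chat: body follows ChatRequest (message plus, presumably optional, history of ChatTurn dicts)
payload = {
    "message": "Hello, can you help me?",
    "history": [],  # prior turns; ignored by the fallback-only endpoint in this commit
}
chat = requests.post(f"{BASE_URL}/api/chat", json=payload, timeout=30)
print(chat.json()["response"])  # e.g. "Hello! I apologize, but ..." (random fallback, prefixed)

# /health: now returns a system_info block with temp-dir and cwd writability checks
health = requests.get(f"{BASE_URL}/health", timeout=30)
print(health.json()["system_info"])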