lahiruchamika27 commited on
Commit
9db031d
·
verified ·
1 Parent(s): 84efcee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -37
app.py CHANGED
@@ -1,17 +1,30 @@
1
  import os
 
 
2
  import torch
3
- from fastapi import FastAPI, HTTPException
 
4
  from pydantic import BaseModel
5
- from typing import List, Dict, Optional
6
  from datasets import load_dataset
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import uvicorn
 
 
 
 
 
 
 
 
 
9
 
10
  app = FastAPI()
11
 
12
  # Global variables
13
  model = None
14
  tokenizer = None
 
15
  dataset = None
16
 
17
  # Pydantic models for request/response
@@ -26,35 +39,67 @@ class ChatRequest(BaseModel):
26
  class ChatResponse(BaseModel):
27
  response: str
28
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Load model on startup
30
  @app.on_event("startup")
31
  async def startup_event():
32
- global model, tokenizer, dataset
 
33
  try:
34
- # Load the model and tokenizer
35
- model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
36
- tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 
37
 
 
38
  model = AutoModelForCausalLM.from_pretrained(
39
- model_id,
40
- torch_dtype=torch.float16,
41
- device_map="auto"
 
42
  )
 
 
 
 
 
 
43
 
44
- # Load dataset
45
- dataset = load_dataset("lahiruchamika27/tia")
46
- print("Model, tokenizer, and dataset loaded successfully!")
 
 
 
 
 
 
 
47
  except Exception as e:
48
- print(f"Error loading model: {str(e)}")
49
- # Continue without failing - we'll handle errors in the endpoints
50
 
51
  @app.post("/api/chat", response_model=ChatResponse)
52
  async def chat(request: ChatRequest):
53
- global model, tokenizer
54
 
55
- # Ensure model is loaded
56
- if model is None or tokenizer is None:
57
- raise HTTPException(status_code=500, detail="Model or tokenizer not loaded")
 
58
 
59
  try:
60
  # Format conversation
@@ -70,30 +115,37 @@ async def chat(request: ChatRequest):
70
  else:
71
  full_prompt = f"User: {request.message}\nAssistant:"
72
 
73
- # Tokenize and generate
74
- inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
75
 
76
- with torch.no_grad():
77
- outputs = model.generate(
78
- inputs["input_ids"],
79
- max_new_tokens=512,
80
- temperature=0.7,
81
- top_p=0.9,
82
- do_sample=True
83
- )
 
 
84
 
85
- # Decode the output
86
- response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
 
87
 
88
- return ChatResponse(response=response.strip())
 
 
 
 
 
89
 
90
  except Exception as e:
91
- raise HTTPException(status_code=500, detail=str(e))
 
92
 
93
  @app.get("/api/examples")
94
  async def get_examples(count: int = 5, split: str = "train"):
95
- global dataset
96
-
97
  if dataset is None:
98
  raise HTTPException(status_code=500, detail="Dataset not loaded")
99
 
@@ -104,13 +156,23 @@ async def get_examples(count: int = 5, split: str = "train"):
104
  return {"examples": examples}
105
  else:
106
  raise HTTPException(status_code=400, detail=f"Split '{split}' not found in dataset")
107
-
108
  except Exception as e:
109
  raise HTTPException(status_code=500, detail=str(e))
110
 
111
  @app.get("/health")
112
  async def health_check():
113
- return {"status": "ok", "model_loaded": model is not None, "tokenizer_loaded": tokenizer is not None}
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  if __name__ == "__main__":
116
  port = int(os.environ.get("PORT", 7860))
 
import logging
import os
import sys
import time
from typing import List, Dict, Optional, Any

import torch
import uvicorn
from datasets import load_dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Configure logging: INFO-level to stdout so the hosting platform's log
# viewer captures every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global variables — all populated by the startup event handler; any of
# them may remain None if loading fails (endpoints must check for that).
model = None      # AutoModelForCausalLM instance
tokenizer = None  # AutoTokenizer instance
generator = None  # transformers text-generation pipeline
dataset = None    # HF dataset used by /api/examples (optional)

# Pydantic models for request/response
 
class ChatResponse(BaseModel):
    """Response payload for /api/chat: the assistant's generated text."""
    response: str
41
 
# Use a much smaller model suitable for Hugging Face Spaces
MODEL_ID = "distilgpt2"  # Using a very small model for testing

# Catch-all error handler: log the full traceback for any exception an
# endpoint did not handle itself, and return a structured JSON 500
# instead of an opaque server error page.
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and convert it to a JSON 500 response."""
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={"detail": f"Internal server error: {str(exc)}"}
    )
53
+
# Load model on startup
@app.on_event("startup")
async def startup_event():
    """Load the tokenizer, model, generation pipeline and dataset.

    Failures are logged but deliberately never re-raised: the API stays up
    so that /health can report which components are missing.

    Fix vs. previous revision: each stage now gets its own timer. Before,
    one start_time was reused, so "Model loaded in ..." and later messages
    reported cumulative elapsed time mislabeled as per-stage time.
    """
    global model, tokenizer, generator, dataset

    try:
        logger.info(f"Loading model: {MODEL_ID}")
        start_time = time.time()  # overall startup timer

        # Load the tokenizer
        stage_start = time.time()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info(f"Tokenizer loaded in {time.time() - stage_start:.2f} seconds")

        # Load the model; fp16 + device_map only when a GPU is present,
        # plain fp32 on CPU-only hosts.
        stage_start = time.time()
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            device_map="auto" if torch.cuda.is_available() else None
        )
        logger.info(f"Model loaded in {time.time() - stage_start:.2f} seconds")

        # Create a text generation pipeline.
        # NOTE(review): recent transformers versions reject an explicit
        # `device=` when the model was loaded with device_map="auto" —
        # confirm against the pinned transformers version on GPU hosts.
        stage_start = time.time()
        device = 0 if torch.cuda.is_available() else -1
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
        logger.info(f"Generator pipeline created in {time.time() - stage_start:.2f} seconds")

        # Dataset is optional: log the error and continue without it.
        try:
            logger.info("Loading dataset: lahiruchamika27/tia")
            dataset = load_dataset("lahiruchamika27/tia")
            logger.info("Dataset loaded successfully")
        except Exception as e:
            logger.error(f"Error loading dataset: {str(e)}")
            logger.info("Continuing without dataset")

        logger.info(f"Startup completed in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error during startup: {str(e)}", exc_info=True)
        logger.info("API will still be available but might not function correctly")
94
 
95
  @app.post("/api/chat", response_model=ChatResponse)
96
  async def chat(request: ChatRequest):
97
+ logger.info(f"Received chat request: {request.message[:50]}...")
98
 
99
+ # Check if model is loaded
100
+ if generator is None:
101
+ logger.error("Text generator not initialized")
102
+ raise HTTPException(status_code=500, detail="Text generation pipeline not initialized")
103
 
104
  try:
105
  # Format conversation
 
115
  else:
116
  full_prompt = f"User: {request.message}\nAssistant:"
117
 
118
+ logger.info(f"Generated prompt: {full_prompt[:100]}...")
 
119
 
120
+ # Generate response
121
+ start_time = time.time()
122
+ outputs = generator(
123
+ full_prompt,
124
+ max_new_tokens=100,
125
+ temperature=0.7,
126
+ top_p=0.9,
127
+ do_sample=True
128
+ )
129
+ logger.info(f"Text generated in {time.time() - start_time:.2f} seconds")
130
 
131
+ # Extract response
132
+ generated_text = outputs[0]['generated_text']
133
+ # Extract only the assistant's response
134
+ response_text = generated_text[len(full_prompt):].strip()
135
 
136
+ # If empty or just whitespace, return a fallback message
137
+ if not response_text or response_text.isspace():
138
+ response_text = "I'm sorry, I'm having trouble generating a response right now."
139
+
140
+ logger.info(f"Final response: {response_text[:50]}...")
141
+ return ChatResponse(response=response_text)
142
 
143
  except Exception as e:
144
+ logger.error(f"Error generating response: {str(e)}", exc_info=True)
145
+ raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")
146
 
147
  @app.get("/api/examples")
148
  async def get_examples(count: int = 5, split: str = "train"):
 
 
149
  if dataset is None:
150
  raise HTTPException(status_code=500, detail="Dataset not loaded")
151
 
 
156
  return {"examples": examples}
157
  else:
158
  raise HTTPException(status_code=400, detail=f"Split '{split}' not found in dataset")
 
159
  except Exception as e:
160
  raise HTTPException(status_code=500, detail=str(e))
161
 
@app.get("/health")
async def health_check():
    """Report which components loaded and the torch device configuration.

    Returns a plain dict; FastAPI serializes it as JSON. Always 200 —
    callers inspect the *_loaded flags to decide whether the service is
    actually usable.
    """
    # Query CUDA availability once instead of three separate calls.
    cuda_available = torch.cuda.is_available()
    system_info = {
        "status": "ok",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "generator_loaded": generator is not None,
        "dataset_loaded": dataset is not None,
        "model_name": MODEL_ID,
        "torch_device": "cuda" if cuda_available else "cpu",
        "cuda_available": cuda_available,
        "cuda_device_count": torch.cuda.device_count() if cuda_available else 0
    }
    return system_info
176
 
177
  if __name__ == "__main__":
178
  port = int(os.environ.get("PORT", 7860))