LucianStorm committed on
Commit b59f9c5 · verified · 1 Parent(s): e1a117d

Update app.py

Files changed (1)
  1. app.py +66 -39
app.py CHANGED
@@ -16,52 +16,76 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-print("Loading model and tokenizer...")
+# Global variables
+model = None
+tokenizer = None
+MODEL_LOADED = False
 
-try:
-    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True,
-        device_map='auto'
-    )
-
-    model.eval()
-    torch.backends.cudnn.benchmark = True
-    print("Model loaded successfully!")
-    MODEL_LOADED = True
-
-except Exception as e:
-    print(f"Error loading model: {e}")
-    MODEL_LOADED = False
+def load_model():
+    global model, tokenizer, MODEL_LOADED
+    try:
+        print("Starting model load...")
+        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+        # CPU-specific settings
+        torch.set_num_threads(4)  # Limit CPU threads
+
+        print("Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            local_files_only=False
+        )
+
+        print("Loading model...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float32,  # Use float32 for CPU
+            low_cpu_mem_usage=True,
+            device_map=None  # Force CPU
+        )
+
+        model.eval()  # Set to evaluation mode
+        MODEL_LOADED = True
+        print("Model loaded successfully on CPU!")
+        return True
+    except Exception as e:
+        print(f"Error loading model: {str(e)}")
+        MODEL_LOADED = False
+        return False
 
+# Load model on startup
+print("Initiating model load...")
+load_model()
+
 class Query(BaseModel):
     prompt: str
-    max_length: int = 150  # Increased for better responses
-    temperature: float = 0.7  # Balanced temperature
+    max_length: int = 100  # Reduced for CPU
+    temperature: float = 0.7
 
 @app.post("/chat")
 async def chat(query: Query):
+    global model, tokenizer, MODEL_LOADED
+
     if not MODEL_LOADED:
-        raise HTTPException(status_code=503, detail="Model not loaded")
+        if not load_model():
+            raise HTTPException(
+                status_code=503,
+                detail="Model is not loaded. Please try again in a minute."
+            )
 
     try:
-        # Better prompt template
-        system_message = """You are a helpful fitness and nutrition assistant.
-        Provide clear, informative answers to help users with their fitness goals.
-        Be friendly but focused on giving practical advice."""
-
-        formatted_prompt = f"<|system|>{system_message}</s><|user|>{query.prompt}</s><|assistant|>"
+        # Simpler prompt template for efficiency
+        formatted_prompt = f"<|user|>{query.prompt}</s><|assistant|>"
 
+        # Tokenize with smaller context
         inputs = tokenizer(
             formatted_prompt,
             return_tensors="pt",
             truncation=True,
-            max_length=512  # Increased context window
-        ).to(model.device)
+            max_length=256  # Reduced context window for CPU
+        )
 
+        # Generate with CPU-optimized settings
         with torch.no_grad():
             outputs = model.generate(
                 inputs["input_ids"],
@@ -70,36 +94,39 @@ async def chat(query: Query):
                 top_p=0.9,
                 do_sample=True,
                 pad_token_id=tokenizer.eos_token_id,
-                no_repeat_ngram_size=3,  # Prevent repetition
-                num_beams=1  # Keep generation fast
+                num_beams=1,  # No beam search for speed
+                early_stopping=True
             )
 
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        # Clean up response
         response = response.split("<|assistant|>")[-1].strip()
 
-        # If response is too short, try to generate more
-        if len(response.split()) < 5:
-            return {"response": "I apologize, but could you please rephrase your question? I'll try to give a more helpful response."}
+        if not response or len(response.split()) < 3:
+            return {"response": "I apologize, could you please rephrase your question?"}
 
         return {"response": response}
-
+
     except Exception as e:
+        print(f"Error during generation: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/")
 def read_root():
     return {
         "status": "API is running!",
-        "model_loaded": MODEL_LOADED
+        "model_loaded": MODEL_LOADED,
+        "backend": "CPU"
     }
 
 @app.get("/debug")
 def debug_info():
     return {
         "model_loaded": MODEL_LOADED,
-        "model_name": model_name if MODEL_LOADED else None,
-        "device": str(next(model.parameters()).device) if MODEL_LOADED else None
+        "device": "cpu",
+        "num_threads": torch.get_num_threads(),
+        "memory_info": {
+            "max_memory": f"{torch.cuda.max_memory_allocated() / 1024**2:.2f}MB" if torch.cuda.is_available() else "CPU only"
+        }
     }
 
 if __name__ == "__main__":
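For reference, a minimal client-side sketch of how the updated endpoints could be exercised. The host and port are assumptions (the `__main__` block is truncated in the visible hunks; this presumes the usual uvicorn setup on `localhost:8000`), and the `max_length`/`temperature` fields are sent only because the `Query` schema declares them — the hunks shown do not include the generate-call lines that would consume them.

```python
# Hypothetical smoke test for the committed API; BASE_URL is an assumption,
# not something shown in the diff.
import requests

BASE_URL = "http://localhost:8000"

# /debug now reports CPU info in place of the removed model_name/device fields
print(requests.get(f"{BASE_URL}/debug").json())

# /chat accepts the Query schema: prompt plus optional max_length/temperature
payload = {"prompt": "Suggest a quick post-workout snack.", "max_length": 100}
resp = requests.post(f"{BASE_URL}/chat", json=payload)
resp.raise_for_status()  # surfaces the 503 raised when load_model() fails
print(resp.json()["response"])
```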