LucianStorm committed
Commit e4aff5c · verified · 1 Parent(s): a4a53e5

Update app.py

Files changed (1)
  1. app.py +39 -41
app.py CHANGED
@@ -3,63 +3,44 @@ from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import uvicorn
-import os
 
 app = FastAPI(title="TinyLlama Fitness Bot")
 
-print("Loading model and tokenizer...")
-
-# Initialize model and tokenizer globally
-try:
-    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float32,
-        low_cpu_mem_usage=True
-    )
-    print("Model and tokenizer loaded successfully!")
-    MODEL_LOADED = True
-except Exception as e:
-    print(f"Error loading model: {e}")
-    MODEL_LOADED = False
+# Initialize model with optimizations
+model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float32,
+    low_cpu_mem_usage=True,
+    device_map='auto'
+)
+
+# Enable model optimization
+model.eval()  # Set to evaluation mode
+torch.backends.cudnn.benchmark = True  # Enable CUDA optimization
 
 class Query(BaseModel):
     prompt: str
-    max_length: int = 256
+    max_length: int = 128  # Reduced max length
     temperature: float = 0.7
 
 class Response(BaseModel):
     response: str
 
-@app.get("/")
-def read_root():
-    return {
-        "status": "API is running!",
-        "model_loaded": MODEL_LOADED
-    }
-
-@app.get("/debug")
-def debug_info():
-    return {
-        "routes": [
-            {"path": route.path, "name": route.name}
-            for route in app.routes
-        ],
-        "model_loaded": MODEL_LOADED,
-        "model_name": model_name if MODEL_LOADED else None,
-    }
-
 @app.post("/chat")
 async def chat(query: Query):
-    if not MODEL_LOADED:
-        raise HTTPException(status_code=503, detail="Model not loaded")
-
     try:
-        system_prompt = """You are a knowledgeable fitness and nutrition assistant."""
-        formatted_prompt = f"<|system|>{system_prompt}</s><|user|>{query.prompt}</s><|assistant|>"
+        # Simplified prompt template
+        formatted_prompt = f"<|user|>{query.prompt}</s><|assistant|>"
 
-        inputs = tokenizer(formatted_prompt, return_tensors="pt")
+        inputs = tokenizer(
+            formatted_prompt,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=query.max_length
+        )
 
         with torch.no_grad():
             outputs = model.generate(
@@ -68,9 +49,13 @@ async def chat(query: Query):
                 temperature=query.temperature,
                 top_p=0.9,
                 do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                num_return_sequences=1,
+                early_stopping=True
             )
 
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Clean up response
        response = response.split("<|assistant|>")[-1].strip()
 
         return Response(response=response)
@@ -78,5 +63,18 @@ async def chat(query: Query):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+# Health check endpoints
+@app.get("/")
+def read_root():
+    return {"status": "API is running!", "model_loaded": True}
+
+@app.get("/debug")
+def debug_info():
+    return {
+        "model_loaded": True,
+        "model_name": model_name,
+        "device": str(next(model.parameters()).device)
+    }
+
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
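
A quick way to exercise the updated /chat endpoint once the app is running. This is a minimal sketch, not part of the commit: it assumes the server is reachable at http://localhost:7860 (the host and port passed to uvicorn.run in app.py) and uses the third-party requests library.

import requests

# Fields mirror the Query model in app.py:
# prompt (required), max_length (default 128), temperature (default 0.7)
payload = {
    "prompt": "Suggest a 20-minute beginner workout.",
    "max_length": 128,
    "temperature": 0.7,
}

resp = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])  # the single field of the Response model

The GET / and GET /debug routes added at the bottom of the file can be probed the same way with requests.get, e.g. to confirm the model loaded and which device device_map='auto' placed it on.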