LucianStorm committed
Commit 9f31314 · verified · 1 Parent(s): e4aff5c

Update app.py

Files changed (1)
app.py  +82  -38
app.py CHANGED
@@ -1,47 +1,99 @@
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import uvicorn
+import os

 app = FastAPI(title="TinyLlama Fitness Bot")

-# Initialize model with optimizations
-model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.float32,
-    low_cpu_mem_usage=True,
-    device_map='auto'
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
 )

-# Enable model optimization
-model.eval()  # Set to evaluation mode
-torch.backends.cudnn.benchmark = True  # Enable CUDA optimization
+# Set environment variables for cache
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
+os.environ['TORCH_HOME'] = '/tmp/torch_cache'
+
+print("Loading model and tokenizer...")
+
+try:
+    # Load model with maximum optimization
+    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        cache_dir='/tmp/transformers_cache'
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,  # Use float16 for faster inference
+        low_cpu_mem_usage=True,
+        device_map='auto',
+        cache_dir='/tmp/transformers_cache'
+    )
+
+    # Enable fast mode
+    model.eval()
+    torch.backends.cudnn.benchmark = True
+    print("Model loaded successfully!")
+    MODEL_LOADED = True
+
+except Exception as e:
+    print(f"Error loading model: {e}")
+    MODEL_LOADED = False

 class Query(BaseModel):
     prompt: str
-    max_length: int = 128  # Reduced max length
-    temperature: float = 0.7
+    max_length: int = 50  # Very short responses
+    temperature: float = 0.8  # Higher temperature for faster responses

 class Response(BaseModel):
     response: str

+@app.get("/")
+def read_root():
+    return {
+        "status": "API is running!",
+        "model_loaded": MODEL_LOADED
+    }
+
+@app.get("/debug")
+def debug_info():
+    return {
+        "model_loaded": MODEL_LOADED,
+        "model_name": model_name if MODEL_LOADED else None,
+        "device": str(next(model.parameters()).device) if MODEL_LOADED else None,
+        "routes": [
+            {"path": route.path, "name": route.name}
+            for route in app.routes
+        ]
+    }
+
 @app.post("/chat")
 async def chat(query: Query):
+    if not MODEL_LOADED:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
     try:
-        # Simplified prompt template
-        formatted_prompt = f"<|user|>{query.prompt}</s><|assistant|>"
+        # Create fitness-focused prompt
+        system_message = "You are a helpful fitness assistant. Provide short, clear answers."
+        formatted_prompt = f"{system_message}\nUser: {query.prompt}\nAssistant:"

+        # Tokenize with truncation
         inputs = tokenizer(
-            formatted_prompt,
+            formatted_prompt,
             return_tensors="pt",
-            padding=True,
             truncation=True,
-            max_length=query.max_length
-        )
+            max_length=32
+        ).to(model.device)

+        # Generate response
         with torch.no_grad():
             outputs = model.generate(
                 inputs["input_ids"],
@@ -49,32 +101,24 @@ async def chat(query: Query):
                 temperature=query.temperature,
                 top_p=0.9,
                 do_sample=True,
-                pad_token_id=tokenizer.eos_token_id,
-                num_return_sequences=1,
-                early_stopping=True
+                num_beams=1,  # No beam search
+                early_stopping=True,
+                pad_token_id=tokenizer.eos_token_id
             )

+        # Decode and clean response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        # Clean up response
-        response = response.split("<|assistant|>")[-1].strip()
+        response = response.split("Assistant:")[-1].strip()

         return Response(response=response)

     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))

-# Health check endpoints
-@app.get("/")
-def read_root():
-    return {"status": "API is running!", "model_loaded": True}
-
-@app.get("/debug")
-def debug_info():
-    return {
-        "model_loaded": True,
-        "model_name": model_name,
-        "device": str(next(model.parameters()).device)
-    }
-
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(
+        "app:app",
+        host="0.0.0.0",
+        port=7860,
+        workers=1
+    )
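
For quick manual testing, a minimal client sketch for the updated endpoints follows. It assumes the server from app.py is reachable at http://localhost:7860 (matching the uvicorn.run call above) and that the third-party requests library is installed; the base URL and the sample prompt are illustrative, while the paths and JSON fields mirror the routes and the Query/Response models in the diff.

import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment; adjust for a hosted Space

# Health check: the "/" route reports whether the model loaded
print(requests.get(f"{BASE_URL}/").json())

# Chat request: fields mirror the Query model (prompt, max_length, temperature)
payload = {
    "prompt": "How many rest days per week should a beginner take?",
    "max_length": 50,
    "temperature": 0.8,
}
reply = requests.post(f"{BASE_URL}/chat", json=payload, timeout=120)
reply.raise_for_status()
print(reply.json()["response"])  # the Response model wraps a single "response" string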