LucianStorm committed on
Commit
a542700
·
verified ·
1 Parent(s): 485b23d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -73
app.py CHANGED
@@ -11,13 +11,11 @@ os.environ['TORCH_HOME'] = '/tmp/torch_cache'
11
 
12
  app = FastAPI(title="DIANA - Diet And Nutrition Assistant")
13
 
14
- app.add_middleware(
15
- CORSMiddleware,
16
- allow_origins=["*"],
17
- allow_credentials=True,
18
- allow_methods=["*"],
19
- allow_headers=["*"],
20
- )
21
 
22
  model = None
23
  tokenizer = None
@@ -28,11 +26,11 @@ def load_model():
28
  try:
29
  print("Starting model load...")
30
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
31
- torch.set_num_threads(4)
32
 
33
  tokenizer = AutoTokenizer.from_pretrained(
34
  model_name,
35
- cache_dir='/tmp/transformers_cache'
 
36
  )
37
 
38
  model = AutoModelForCausalLM.from_pretrained(
@@ -41,7 +39,7 @@ def load_model():
41
  low_cpu_mem_usage=True,
42
  device_map=None,
43
  cache_dir='/tmp/transformers_cache'
44
- )
45
 
46
  model.eval()
47
  MODEL_LOADED = True
@@ -56,106 +54,104 @@ load_model()
56
 
57
  class Query(BaseModel):
58
  prompt: str
59
- max_length: int = 200
60
  temperature: float = 0.7
61
 
62
- def is_greeting(text):
63
- greetings = ['hi', 'hello', 'hey', 'good morning', 'good afternoon', 'good evening', 'greetings']
64
- return any(greeting in text.lower() for greeting in greetings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- def is_fitness_question(text):
67
- fitness_keywords = [
68
- 'workout', 'exercise', 'training', 'muscle', 'strength', 'cardio', 'weight',
69
- 'diet', 'nutrition', 'protein', 'carbs', 'fat', 'meal', 'food', 'eating',
70
- 'routine', 'program', 'sets', 'reps', 'gym', 'fitness', 'health'
71
- ]
72
- return any(keyword in text.lower() for keyword in fitness_keywords)
73
 
74
  @app.post("/chat")
75
  async def chat(query: Query):
76
  if not MODEL_LOADED:
77
- if not load_model():
78
- raise HTTPException(
79
- status_code=503,
80
- detail="DIANA is still initializing. Please try again in a minute."
81
- )
82
 
83
  try:
84
- # Personalized system prompts
85
  if is_greeting(query.prompt):
86
- system_prompt = """You are DIANA (Diet And Nutrition Assistant), a friendly and knowledgeable
87
- fitness companion. Always respond warmly and offer to help with fitness and nutrition guidance.
88
- Sign your responses with '- DIANA 💪'"""
89
- else:
90
- system_prompt = """You are DIANA (Diet And Nutrition Assistant), a knowledgeable fitness and
91
- nutrition guide. Provide practical, safe, and evidence-based advice about workouts, nutrition,
92
- and healthy living. Include:
93
- 1. Clear, actionable recommendations
94
- 2. Safety considerations
95
- 3. Beginner-friendly explanations
96
- Remember to sign your responses with '- DIANA 💪'"""
97
 
98
- formatted_prompt = f"""<|system|>{system_prompt}</s>
99
- <|user|>{query.prompt}</s>
100
- <|assistant|>"""
 
 
 
 
 
 
 
101
 
102
  inputs = tokenizer(
103
  formatted_prompt,
104
  return_tensors="pt",
105
  truncation=True,
106
- max_length=300
107
- )
 
108
 
109
- with torch.no_grad():
110
  outputs = model.generate(
111
  inputs["input_ids"],
112
- max_new_tokens=200,
113
- min_new_tokens=50,
114
  temperature=0.7,
115
  top_p=0.9,
116
  do_sample=True,
117
  pad_token_id=tokenizer.eos_token_id,
118
  repetition_penalty=1.2,
119
  no_repeat_ngram_size=3,
120
- eos_token_id=tokenizer.eos_token_id,
121
- early_stopping=True
 
 
122
  )
123
 
124
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
125
- response = response.split("<|assistant|>")[-1].strip()
 
 
 
 
 
 
 
 
126
 
127
- # Add signature if not present
128
  if "- DIANA 💪" not in response:
129
- response = response + "\n\n- DIANA 💪"
130
 
131
- # Response validation and fallbacks
132
- if not response or len(response.split()) < 20:
133
- if is_greeting(query.prompt):
134
- return {
135
- "response": "Hi there! I'm DIANA, your personal Diet And Nutrition Assistant. I'm here to help you achieve your health and fitness goals! Would you like some advice about workouts or nutrition?\n\n- DIANA 💪"
136
- }
137
- elif is_fitness_question(query.prompt):
138
- return {
139
- "response": "Let me help you on your fitness journey! Could you provide more details about your specific goals and current fitness level? This will help me give you the most relevant advice.\n\n- DIANA 💪"
140
- }
141
- else:
142
- return {
143
- "response": "Hi! I'm DIANA, your Diet And Nutrition Assistant. I specialize in workout plans, diet advice, and general health tips. What would you like to know more about?\n\n- DIANA 💪"
144
- }
145
-
146
  return {"response": response}
147
 
148
  except Exception as e:
149
- print(f"Error during generation: {str(e)}")
150
- raise HTTPException(status_code=500, detail=str(e))
151
 
152
  @app.get("/")
153
  def read_root():
154
- return {
155
- "status": "DIANA (Diet And Nutrition Assistant) is running!",
156
- "model_loaded": MODEL_LOADED,
157
- "specialties": ["Personalized workout advice", "Nutrition guidance", "Fitness planning"]
158
- }
159
 
160
  if __name__ == "__main__":
161
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
11
 
12
  app = FastAPI(title="DIANA - Diet And Nutrition Assistant")
13
 
14
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
15
+
16
+ DEVICE = torch.device('cpu')
17
+ torch.set_num_threads(4)
18
+ torch.set_grad_enabled(False)
 
 
19
 
20
  model = None
21
  tokenizer = None
 
26
  try:
27
  print("Starting model load...")
28
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
29
 
30
  tokenizer = AutoTokenizer.from_pretrained(
31
  model_name,
32
+ cache_dir='/tmp/transformers_cache',
33
+ use_fast=True
34
  )
35
 
36
  model = AutoModelForCausalLM.from_pretrained(
 
39
  low_cpu_mem_usage=True,
40
  device_map=None,
41
  cache_dir='/tmp/transformers_cache'
42
+ ).to(DEVICE)
43
 
44
  model.eval()
45
  MODEL_LOADED = True
 
54
 
55
  class Query(BaseModel):
56
  prompt: str
57
+ max_length: int = 150
58
  temperature: float = 0.7
59
 
60
+ def get_structured_response(topic):
61
+ return f"""Here's what you need to know about {topic}:
62
+
63
+ 1. Start with the basics:
64
+ • Begin gradually
65
+ • Focus on proper form
66
+ • Stay consistent
67
+
68
+ 2. Key points to remember:
69
+ • Set realistic goals
70
+ • Track your progress
71
+ • Listen to your body
72
+
73
+ 3. Tips for success:
74
+ • Start today, not tomorrow
75
+ • Keep it simple
76
+ • Stay motivated
77
+
78
+ Need more specific advice about any of these points?
79
 
80
+ - DIANA 💪"""
81
+
82
+ def is_greeting(text):
83
+ return any(g in text.lower() for g in ['hi', 'hello', 'hey'])
 
 
 
84
 
85
  @app.post("/chat")
86
  async def chat(query: Query):
87
  if not MODEL_LOADED:
88
+ raise HTTPException(status_code=503, detail="DIANA is initializing. Please try again.")
 
 
 
 
89
 
90
  try:
91
+ # Handle greetings
92
  if is_greeting(query.prompt):
93
+ return {"response": "Hi! I'm DIANA, your fitness assistant. How can I help you today?\n\n- DIANA 💪"}
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # Optimized but complete prompt template
96
+ system_prompt = f"""You are DIANA, a fitness assistant. Give clear, complete advice about {query.prompt}.
97
+ Structure your response like this:
98
+ 1. Brief welcome and intro
99
+ 2. 3 main points with bullets
100
+ 3. Encouraging conclusion
101
+ 4. Sign with '- DIANA 💪'
102
+ IMPORTANT: Never end mid-sentence. Always complete your thoughts."""
103
+
104
+ formatted_prompt = f"<|system|>{system_prompt}</s><|user|>Give structured fitness advice about: {query.prompt}</s><|assistant|>Let me help you with that!\n\n"
105
 
106
  inputs = tokenizer(
107
  formatted_prompt,
108
  return_tensors="pt",
109
  truncation=True,
110
+ max_length=200,
111
+ padding=False
112
+ ).to(DEVICE)
113
 
114
+ with torch.inference_mode():
115
  outputs = model.generate(
116
  inputs["input_ids"],
117
+ max_new_tokens=150,
118
+ min_new_tokens=100, # Ensure minimum length
119
  temperature=0.7,
120
  top_p=0.9,
121
  do_sample=True,
122
  pad_token_id=tokenizer.eos_token_id,
123
  repetition_penalty=1.2,
124
  no_repeat_ngram_size=3,
125
+ eos_token_id=tokenizer.eos_token_id, # Proper ending
126
+ num_beams=1,
127
+ early_stopping=True,
128
+ use_cache=True
129
  )
130
 
131
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
132
+ response = response.split("Let me help you with that!")[-1].strip()
133
+
134
+ # Validate response completeness
135
+ sentences = [s.strip() for s in response.split('.') if s.strip()]
136
+ words = response.split()
137
+
138
+ # If response might be incomplete, use structured format
139
+ if len(sentences) < 4 or len(words) < 50 or not response.endswith(('!', '.', '?', '💪')):
140
+ return {"response": get_structured_response(query.prompt)}
141
 
142
+ # Ensure proper signature
143
  if "- DIANA 💪" not in response:
144
+ response += "\n\n- DIANA 💪"
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  return {"response": response}
147
 
148
  except Exception as e:
149
+ print(f"Error: {str(e)}")
150
+ return {"response": get_structured_response(query.prompt)}
151
 
152
  @app.get("/")
153
  def read_root():
154
+ return {"status": "DIANA is ready!", "model_loaded": MODEL_LOADED}
 
 
 
 
155
 
156
  if __name__ == "__main__":
157
  uvicorn.run("app:app", host="0.0.0.0", port=7860)