Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,13 +11,11 @@ os.environ['TORCH_HOME'] = '/tmp/torch_cache'
|
|
| 11 |
|
| 12 |
app = FastAPI(title="DIANA - Diet And Nutrition Assistant")
|
| 13 |
|
| 14 |
-
app.add_middleware(
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
allow_headers=["*"],
|
| 20 |
-
)
|
| 21 |
|
| 22 |
model = None
|
| 23 |
tokenizer = None
|
|
@@ -28,11 +26,11 @@ def load_model():
|
|
| 28 |
try:
|
| 29 |
print("Starting model load...")
|
| 30 |
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 31 |
-
torch.set_num_threads(4)
|
| 32 |
|
| 33 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 34 |
model_name,
|
| 35 |
-
cache_dir='/tmp/transformers_cache'
|
|
|
|
| 36 |
)
|
| 37 |
|
| 38 |
model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -41,7 +39,7 @@ def load_model():
|
|
| 41 |
low_cpu_mem_usage=True,
|
| 42 |
device_map=None,
|
| 43 |
cache_dir='/tmp/transformers_cache'
|
| 44 |
-
)
|
| 45 |
|
| 46 |
model.eval()
|
| 47 |
MODEL_LOADED = True
|
|
@@ -56,106 +54,104 @@ load_model()
|
|
| 56 |
|
| 57 |
class Query(BaseModel):
|
| 58 |
prompt: str
|
| 59 |
-
max_length: int =
|
| 60 |
temperature: float = 0.7
|
| 61 |
|
| 62 |
-
def
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
'routine', 'program', 'sets', 'reps', 'gym', 'fitness', 'health'
|
| 71 |
-
]
|
| 72 |
-
return any(keyword in text.lower() for keyword in fitness_keywords)
|
| 73 |
|
| 74 |
@app.post("/chat")
|
| 75 |
async def chat(query: Query):
|
| 76 |
if not MODEL_LOADED:
|
| 77 |
-
|
| 78 |
-
raise HTTPException(
|
| 79 |
-
status_code=503,
|
| 80 |
-
detail="DIANA is still initializing. Please try again in a minute."
|
| 81 |
-
)
|
| 82 |
|
| 83 |
try:
|
| 84 |
-
#
|
| 85 |
if is_greeting(query.prompt):
|
| 86 |
-
|
| 87 |
-
fitness companion. Always respond warmly and offer to help with fitness and nutrition guidance.
|
| 88 |
-
Sign your responses with '- DIANA 💪'"""
|
| 89 |
-
else:
|
| 90 |
-
system_prompt = """You are DIANA (Diet And Nutrition Assistant), a knowledgeable fitness and
|
| 91 |
-
nutrition guide. Provide practical, safe, and evidence-based advice about workouts, nutrition,
|
| 92 |
-
and healthy living. Include:
|
| 93 |
-
1. Clear, actionable recommendations
|
| 94 |
-
2. Safety considerations
|
| 95 |
-
3. Beginner-friendly explanations
|
| 96 |
-
Remember to sign your responses with '- DIANA 💪'"""
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
inputs = tokenizer(
|
| 103 |
formatted_prompt,
|
| 104 |
return_tensors="pt",
|
| 105 |
truncation=True,
|
| 106 |
-
max_length=
|
| 107 |
-
|
|
|
|
| 108 |
|
| 109 |
-
with torch.
|
| 110 |
outputs = model.generate(
|
| 111 |
inputs["input_ids"],
|
| 112 |
-
max_new_tokens=
|
| 113 |
-
min_new_tokens=
|
| 114 |
temperature=0.7,
|
| 115 |
top_p=0.9,
|
| 116 |
do_sample=True,
|
| 117 |
pad_token_id=tokenizer.eos_token_id,
|
| 118 |
repetition_penalty=1.2,
|
| 119 |
no_repeat_ngram_size=3,
|
| 120 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
|
| 124 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 125 |
-
response = response.split("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
-
#
|
| 128 |
if "- DIANA 💪" not in response:
|
| 129 |
-
response =
|
| 130 |
|
| 131 |
-
# Response validation and fallbacks
|
| 132 |
-
if not response or len(response.split()) < 20:
|
| 133 |
-
if is_greeting(query.prompt):
|
| 134 |
-
return {
|
| 135 |
-
"response": "Hi there! I'm DIANA, your personal Diet And Nutrition Assistant. I'm here to help you achieve your health and fitness goals! Would you like some advice about workouts or nutrition?\n\n- DIANA 💪"
|
| 136 |
-
}
|
| 137 |
-
elif is_fitness_question(query.prompt):
|
| 138 |
-
return {
|
| 139 |
-
"response": "Let me help you on your fitness journey! Could you provide more details about your specific goals and current fitness level? This will help me give you the most relevant advice.\n\n- DIANA 💪"
|
| 140 |
-
}
|
| 141 |
-
else:
|
| 142 |
-
return {
|
| 143 |
-
"response": "Hi! I'm DIANA, your Diet And Nutrition Assistant. I specialize in workout plans, diet advice, and general health tips. What would you like to know more about?\n\n- DIANA 💪"
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
return {"response": response}
|
| 147 |
|
| 148 |
except Exception as e:
|
| 149 |
-
print(f"Error
|
| 150 |
-
|
| 151 |
|
| 152 |
@app.get("/")
|
| 153 |
def read_root():
|
| 154 |
-
return {
|
| 155 |
-
"status": "DIANA (Diet And Nutrition Assistant) is running!",
|
| 156 |
-
"model_loaded": MODEL_LOADED,
|
| 157 |
-
"specialties": ["Personalized workout advice", "Nutrition guidance", "Fitness planning"]
|
| 158 |
-
}
|
| 159 |
|
| 160 |
if __name__ == "__main__":
|
| 161 |
uvicorn.run("app:app", host="0.0.0.0", port=7860)
|
|
|
|
| 11 |
|
| 12 |
app = FastAPI(title="DIANA - Diet And Nutrition Assistant")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# CORS wildcard-with-credentials combo browsers reject for credentialed
# requests — confirm whether credentials are actually needed here.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

# CPU-only inference: pin the device, cap intra-op threads, and disable
# autograd globally (the app only ever runs generation, never training).
DEVICE = torch.device('cpu')
torch.set_num_threads(4)
torch.set_grad_enabled(False)

# Populated by load_model() at startup; None until the model is ready.
model = None
tokenizer = None
|
|
|
|
| 26 |
try:
|
| 27 |
print("Starting model load...")
|
| 28 |
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
|
|
|
| 29 |
|
| 30 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 31 |
model_name,
|
| 32 |
+
cache_dir='/tmp/transformers_cache',
|
| 33 |
+
use_fast=True
|
| 34 |
)
|
| 35 |
|
| 36 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 39 |
low_cpu_mem_usage=True,
|
| 40 |
device_map=None,
|
| 41 |
cache_dir='/tmp/transformers_cache'
|
| 42 |
+
).to(DEVICE)
|
| 43 |
|
| 44 |
model.eval()
|
| 45 |
MODEL_LOADED = True
|
|
|
|
| 54 |
|
| 55 |
class Query(BaseModel):
    """Request body for POST /chat."""

    # User's question / message text.
    prompt: str
    # Upper bound on generated tokens for this request.
    max_length: int = 150
    # Sampling temperature for generation.
    temperature: float = 0.7
|
| 59 |
|
def get_structured_response(topic):
    """Return a canned, well-formed advice reply about *topic*.

    Used as the fallback whenever model generation fails or produces a
    response that looks truncated, so the client always receives a
    complete, signed answer.
    """
    return f"""Here's what you need to know about {topic}:

1. Start with the basics:
• Begin gradually
• Focus on proper form
• Stay consistent

2. Key points to remember:
• Set realistic goals
• Track your progress
• Listen to your body

3. Tips for success:
• Start today, not tomorrow
• Keep it simple
• Stay motivated

Need more specific advice about any of these points?

- DIANA 💪"""
| 81 |
+
|
| 82 |
+
def is_greeting(text):
|
| 83 |
+
return any(g in text.lower() for g in ['hi', 'hello', 'hey'])
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
@app.post("/chat")
async def chat(query: Query):
    """Generate a fitness/nutrition reply for ``query.prompt``.

    Returns ``{"response": str}``. Raises HTTP 503 while the model is
    still loading. If generation fails or the output looks truncated,
    falls back to the canned ``get_structured_response`` reply so the
    client always gets a complete, signed answer.
    """
    if not MODEL_LOADED:
        raise HTTPException(status_code=503, detail="DIANA is initializing. Please try again.")

    try:
        # Handle greetings with a fixed fast-path reply (no model call).
        if is_greeting(query.prompt):
            return {"response": "Hi! I'm DIANA, your fitness assistant. How can I help you today?\n\n- DIANA 💪"}

        # Optimized but complete prompt template
        system_prompt = f"""You are DIANA, a fitness assistant. Give clear, complete advice about {query.prompt}.
Structure your response like this:
1. Brief welcome and intro
2. 3 main points with bullets
3. Encouraging conclusion
4. Sign with '- DIANA 💪'
IMPORTANT: Never end mid-sentence. Always complete your thoughts."""

        # TinyLlama chat template; the seeded assistant turn ("Let me help
        # you with that!") is split off from the decoded output below.
        formatted_prompt = f"<|system|>{system_prompt}</s><|user|>Give structured fitness advice about: {query.prompt}</s><|assistant|>Let me help you with that!\n\n"

        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=200,
            padding=False
        ).to(DEVICE)

        with torch.inference_mode():
            outputs = model.generate(
                inputs["input_ids"],
                # Pass the mask explicitly: generate() warns and may infer
                # it incorrectly when pad and eos tokens coincide.
                attention_mask=inputs["attention_mask"],
                # Honour the per-request knobs from Query instead of the old
                # hard-coded 150 / 0.7 (the Query defaults keep behaviour
                # identical for callers that don't set them).
                max_new_tokens=query.max_length,
                # Ensure minimum length, but never above the requested max.
                min_new_tokens=min(100, query.max_length),
                temperature=query.temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                eos_token_id=tokenizer.eos_token_id,  # proper ending
                num_beams=1,
                # NOTE: early_stopping removed — it is a beam-search flag and
                # was a no-op (plus a UserWarning) with num_beams=1.
                use_cache=True
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the seeded assistant opener.
        response = response.split("Let me help you with that!")[-1].strip()

        # Validate response completeness
        sentences = [s.strip() for s in response.split('.') if s.strip()]
        words = response.split()

        # If response might be incomplete, use structured format
        if len(sentences) < 4 or len(words) < 50 or not response.endswith(('!', '.', '?', '💪')):
            return {"response": get_structured_response(query.prompt)}

        # Ensure proper signature
        if "- DIANA 💪" not in response:
            response += "\n\n- DIANA 💪"

        return {"response": response}

    except Exception as e:
        print(f"Error: {str(e)}")
        return {"response": get_structured_response(query.prompt)}
|
| 151 |
|
| 152 |
@app.get("/")
|
| 153 |
def read_root():
|
| 154 |
+
return {"status": "DIANA is ready!", "model_loaded": MODEL_LOADED}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
# Local/dev entry point; the HF Spaces runtime binds the same host/port.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
|