Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,11 +6,10 @@ import torch
|
|
| 6 |
import uvicorn
|
| 7 |
import os
|
| 8 |
|
| 9 |
-
# Set cache directories to /tmp which is writable
|
| 10 |
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
|
| 11 |
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
|
| 12 |
|
| 13 |
-
app = FastAPI(title="
|
| 14 |
|
| 15 |
app.add_middleware(
|
| 16 |
CORSMiddleware,
|
|
@@ -20,7 +19,6 @@ app.add_middleware(
|
|
| 20 |
allow_headers=["*"],
|
| 21 |
)
|
| 22 |
|
| 23 |
-
# Global variables
|
| 24 |
model = None
|
| 25 |
tokenizer = None
|
| 26 |
MODEL_LOADED = False
|
|
@@ -30,81 +28,120 @@ def load_model():
|
|
| 30 |
try:
|
| 31 |
print("Starting model load...")
|
| 32 |
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 33 |
-
|
| 34 |
-
# CPU-specific settings
|
| 35 |
torch.set_num_threads(4)
|
| 36 |
|
| 37 |
-
print("Loading tokenizer...")
|
| 38 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 39 |
model_name,
|
| 40 |
-
cache_dir='/tmp/transformers_cache'
|
| 41 |
)
|
| 42 |
|
| 43 |
-
print("Loading model...")
|
| 44 |
model = AutoModelForCausalLM.from_pretrained(
|
| 45 |
model_name,
|
| 46 |
torch_dtype=torch.float32,
|
| 47 |
low_cpu_mem_usage=True,
|
| 48 |
-
device_map=None,
|
| 49 |
-
cache_dir='/tmp/transformers_cache'
|
| 50 |
)
|
| 51 |
|
| 52 |
model.eval()
|
| 53 |
MODEL_LOADED = True
|
| 54 |
-
print("Model loaded successfully on CPU!")
|
| 55 |
return True
|
| 56 |
except Exception as e:
|
| 57 |
print(f"Error loading model: {str(e)}")
|
| 58 |
MODEL_LOADED = False
|
| 59 |
return False
|
| 60 |
|
| 61 |
-
|
| 62 |
-
print("Initiating model load...")
|
| 63 |
load_model()
|
| 64 |
|
| 65 |
class Query(BaseModel):
|
| 66 |
prompt: str
|
| 67 |
-
max_length: int =
|
| 68 |
temperature: float = 0.7
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
@app.post("/chat")
|
| 71 |
async def chat(query: Query):
|
| 72 |
-
global model, tokenizer, MODEL_LOADED
|
| 73 |
-
|
| 74 |
if not MODEL_LOADED:
|
| 75 |
if not load_model():
|
| 76 |
raise HTTPException(
|
| 77 |
status_code=503,
|
| 78 |
-
detail="
|
| 79 |
)
|
| 80 |
|
| 81 |
try:
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
inputs = tokenizer(
|
| 85 |
formatted_prompt,
|
| 86 |
return_tensors="pt",
|
| 87 |
truncation=True,
|
| 88 |
-
max_length=
|
| 89 |
)
|
| 90 |
|
| 91 |
with torch.no_grad():
|
| 92 |
outputs = model.generate(
|
| 93 |
inputs["input_ids"],
|
| 94 |
-
max_new_tokens=
|
| 95 |
-
|
|
|
|
| 96 |
top_p=0.9,
|
| 97 |
do_sample=True,
|
| 98 |
pad_token_id=tokenizer.eos_token_id,
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
early_stopping=True
|
| 101 |
)
|
| 102 |
|
| 103 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 104 |
response = response.split("<|assistant|>")[-1].strip()
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
return {"response": response}
|
| 110 |
|
|
@@ -115,18 +152,9 @@ async def chat(query: Query):
|
|
| 115 |
@app.get("/")
|
| 116 |
def read_root():
|
| 117 |
return {
|
| 118 |
-
"status": "
|
| 119 |
-
"model_loaded": MODEL_LOADED,
|
| 120 |
-
"backend": "CPU"
|
| 121 |
-
}
|
| 122 |
-
|
| 123 |
-
@app.get("/debug")
|
| 124 |
-
def debug_info():
|
| 125 |
-
return {
|
| 126 |
"model_loaded": MODEL_LOADED,
|
| 127 |
-
"
|
| 128 |
-
"num_threads": torch.get_num_threads(),
|
| 129 |
-
"cache_dir": os.environ.get('TRANSFORMERS_CACHE')
|
| 130 |
}
|
| 131 |
|
| 132 |
if __name__ == "__main__":
|
|
|
|
| 6 |
import uvicorn
|
| 7 |
import os
|
| 8 |
|
|
|
|
| 9 |
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
|
| 10 |
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
|
| 11 |
|
| 12 |
+
app = FastAPI(title="DIANA - Diet And Nutrition Assistant")
|
| 13 |
|
| 14 |
app.add_middleware(
|
| 15 |
CORSMiddleware,
|
|
|
|
| 19 |
allow_headers=["*"],
|
| 20 |
)
|
| 21 |
|
|
|
|
| 22 |
model = None
|
| 23 |
tokenizer = None
|
| 24 |
MODEL_LOADED = False
|
|
|
|
| 28 |
try:
|
| 29 |
print("Starting model load...")
|
| 30 |
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
|
|
|
|
|
|
| 31 |
torch.set_num_threads(4)
|
| 32 |
|
|
|
|
| 33 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 34 |
model_name,
|
| 35 |
+
cache_dir='/tmp/transformers_cache'
|
| 36 |
)
|
| 37 |
|
|
|
|
| 38 |
model = AutoModelForCausalLM.from_pretrained(
|
| 39 |
model_name,
|
| 40 |
torch_dtype=torch.float32,
|
| 41 |
low_cpu_mem_usage=True,
|
| 42 |
+
device_map=None,
|
| 43 |
+
cache_dir='/tmp/transformers_cache'
|
| 44 |
)
|
| 45 |
|
| 46 |
model.eval()
|
| 47 |
MODEL_LOADED = True
|
|
|
|
| 48 |
return True
|
| 49 |
except Exception as e:
|
| 50 |
print(f"Error loading model: {str(e)}")
|
| 51 |
MODEL_LOADED = False
|
| 52 |
return False
|
| 53 |
|
| 54 |
+
print("Initiating DIANA...")
|
|
|
|
| 55 |
load_model()
|
| 56 |
|
| 57 |
class Query(BaseModel):
|
| 58 |
prompt: str
|
| 59 |
+
max_length: int = 200
|
| 60 |
temperature: float = 0.7
|
| 61 |
|
| 62 |
+
def is_greeting(text):
|
| 63 |
+
greetings = ['hi', 'hello', 'hey', 'good morning', 'good afternoon', 'good evening', 'greetings']
|
| 64 |
+
return any(greeting in text.lower() for greeting in greetings)
|
| 65 |
+
|
| 66 |
+
def is_fitness_question(text):
|
| 67 |
+
fitness_keywords = [
|
| 68 |
+
'workout', 'exercise', 'training', 'muscle', 'strength', 'cardio', 'weight',
|
| 69 |
+
'diet', 'nutrition', 'protein', 'carbs', 'fat', 'meal', 'food', 'eating',
|
| 70 |
+
'routine', 'program', 'sets', 'reps', 'gym', 'fitness', 'health'
|
| 71 |
+
]
|
| 72 |
+
return any(keyword in text.lower() for keyword in fitness_keywords)
|
| 73 |
+
|
| 74 |
@app.post("/chat")
|
| 75 |
async def chat(query: Query):
|
|
|
|
|
|
|
| 76 |
if not MODEL_LOADED:
|
| 77 |
if not load_model():
|
| 78 |
raise HTTPException(
|
| 79 |
status_code=503,
|
| 80 |
+
detail="DIANA is still initializing. Please try again in a minute."
|
| 81 |
)
|
| 82 |
|
| 83 |
try:
|
| 84 |
+
# Personalized system prompts
|
| 85 |
+
if is_greeting(query.prompt):
|
| 86 |
+
system_prompt = """You are DIANA (Diet And Nutrition Assistant), a friendly and knowledgeable
|
| 87 |
+
fitness companion. Always respond warmly and offer to help with fitness and nutrition guidance.
|
| 88 |
+
Sign your responses with '- DIANA πͺ'"""
|
| 89 |
+
else:
|
| 90 |
+
system_prompt = """You are DIANA (Diet And Nutrition Assistant), a knowledgeable fitness and
|
| 91 |
+
nutrition guide. Provide practical, safe, and evidence-based advice about workouts, nutrition,
|
| 92 |
+
and healthy living. Include:
|
| 93 |
+
1. Clear, actionable recommendations
|
| 94 |
+
2. Safety considerations
|
| 95 |
+
3. Beginner-friendly explanations
|
| 96 |
+
Remember to sign your responses with '- DIANA πͺ'"""
|
| 97 |
+
|
| 98 |
+
formatted_prompt = f"""<|system|>{system_prompt}</s>
|
| 99 |
+
<|user|>{query.prompt}</s>
|
| 100 |
+
<|assistant|>"""
|
| 101 |
|
| 102 |
inputs = tokenizer(
|
| 103 |
formatted_prompt,
|
| 104 |
return_tensors="pt",
|
| 105 |
truncation=True,
|
| 106 |
+
max_length=300
|
| 107 |
)
|
| 108 |
|
| 109 |
with torch.no_grad():
|
| 110 |
outputs = model.generate(
|
| 111 |
inputs["input_ids"],
|
| 112 |
+
max_new_tokens=200,
|
| 113 |
+
min_new_tokens=50,
|
| 114 |
+
temperature=0.7,
|
| 115 |
top_p=0.9,
|
| 116 |
do_sample=True,
|
| 117 |
pad_token_id=tokenizer.eos_token_id,
|
| 118 |
+
repetition_penalty=1.2,
|
| 119 |
+
no_repeat_ngram_size=3,
|
| 120 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 121 |
early_stopping=True
|
| 122 |
)
|
| 123 |
|
| 124 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 125 |
response = response.split("<|assistant|>")[-1].strip()
|
| 126 |
|
| 127 |
+
# Add signature if not present
|
| 128 |
+
if "- DIANA πͺ" not in response:
|
| 129 |
+
response = response + "\n\n- DIANA πͺ"
|
| 130 |
+
|
| 131 |
+
# Response validation and fallbacks
|
| 132 |
+
if not response or len(response.split()) < 20:
|
| 133 |
+
if is_greeting(query.prompt):
|
| 134 |
+
return {
|
| 135 |
+
"response": "Hi there! I'm DIANA, your personal Diet And Nutrition Assistant. I'm here to help you achieve your health and fitness goals! Would you like some advice about workouts or nutrition?\n\n- DIANA πͺ"
|
| 136 |
+
}
|
| 137 |
+
elif is_fitness_question(query.prompt):
|
| 138 |
+
return {
|
| 139 |
+
"response": "Let me help you on your fitness journey! Could you provide more details about your specific goals and current fitness level? This will help me give you the most relevant advice.\n\n- DIANA πͺ"
|
| 140 |
+
}
|
| 141 |
+
else:
|
| 142 |
+
return {
|
| 143 |
+
"response": "Hi! I'm DIANA, your Diet And Nutrition Assistant. I specialize in workout plans, diet advice, and general health tips. What would you like to know more about?\n\n- DIANA πͺ"
|
| 144 |
+
}
|
| 145 |
|
| 146 |
return {"response": response}
|
| 147 |
|
|
|
|
| 152 |
@app.get("/")
|
| 153 |
def read_root():
|
| 154 |
return {
|
| 155 |
+
"status": "DIANA (Diet And Nutrition Assistant) is running!",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
"model_loaded": MODEL_LOADED,
|
| 157 |
+
"specialties": ["Personalized workout advice", "Nutrition guidance", "Fitness planning"]
|
|
|
|
|
|
|
| 158 |
}
|
| 159 |
|
| 160 |
if __name__ == "__main__":
|