Update app.py
app.py CHANGED
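The first hunk opens at line 7 with "import os" as context, so the file's import block is outside the diff. For the added code to run, app.py would need imports along these lines (a reconstruction inferred from the names the new code uses, not part of this commit):

# Presumed import block -- inferred from names used below, not shown in the diff.
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM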
@@ -7,12 +7,38 @@ import os
 
 app = FastAPI(title="TinyLlama Fitness Bot")
 
-
+print("Loading model and tokenizer...")
+
+# Initialize model and tokenizer globally
+try:
+    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True
+    )
+    print("Model and tokenizer loaded successfully!")
+    MODEL_LOADED = True
+except Exception as e:
+    print(f"Error loading model: {e}")
+    MODEL_LOADED = False
+
+class Query(BaseModel):
+    prompt: str
+    max_length: int = 256
+    temperature: float = 0.7
+
+class Response(BaseModel):
+    response: str
+
 @app.get("/")
 def read_root():
-    return {
+    return {
+        "status": "API is running!",
+        "model_loaded": MODEL_LOADED
+    }
 
-# Test route to check environment
 @app.get("/debug")
 def debug_info():
     return {
@@ -20,20 +46,15 @@ def debug_info():
             {"path": route.path, "name": route.name}
             for route in app.routes
         ],
-        "model_loaded":
-        "
+        "model_loaded": MODEL_LOADED,
+        "model_name": model_name if MODEL_LOADED else None,
     }
 
-
-    prompt: str
-    max_length: int = 256
-    temperature: float = 0.7
-
-class Response(BaseModel):
-    response: str
-
-@app.post("/chat", response_model=Response)
+@app.post("/chat")
 async def chat(query: Query):
+    if not MODEL_LOADED:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
     try:
         system_prompt = """You are a knowledgeable fitness and nutrition assistant."""
         formatted_prompt = f"<|system|>{system_prompt}</s><|user|>{query.prompt}</s><|assistant|>"
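Both sides of the diff cut off here, inside chat()'s try block, right after formatted_prompt is built, so the generation step itself is not visible. Given the tokenizer and model globals and the Query fields added above, the handler presumably continues roughly like this (a sketch under those assumptions; the variable names are illustrative, not the file's actual code):

        # Hypothetical continuation of chat(); the diff ends before this point.
        inputs = tokenizer(formatted_prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
            max_new_tokens=query.max_length,  # treating max_length as the new-token budget
            temperature=query.temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # TinyLlama's chat template marks the reply with <|assistant|>;
        # keep only the text after the last marker.
        answer = text.split("<|assistant|>")[-1].strip()
        return Response(response=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

Loading the model once at import time and gating requests on MODEL_LOADED keeps per-request latency down, and lets / and /debug report a failed load instead of the whole app crashing on startup.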
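A side note on the hand-built f-string: recent transformers releases can render the model's own chat template from the tokenizer, which avoids drift if the prompt format or model changes (an alternative approach, not what this commit does):

# Alternative to the manual f-string: let the tokenizer render its chat template.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": query.prompt},
]
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)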
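Once the Space is up, the three routes can be smoke-tested from any HTTP client; for example (the base URL is a placeholder for the Space's actual address; 7860 is the usual Spaces default port):

# Quick smoke test; base URL is a placeholder.
import requests

base = "http://localhost:7860"

print(requests.get(f"{base}/").json())       # expect {"status": "API is running!", "model_loaded": true}
print(requests.get(f"{base}/debug").json())  # route list plus model info

r = requests.post(f"{base}/chat", json={
    "prompt": "How much protein should I eat after a workout?",
    "max_length": 128,   # Query defaults: max_length=256, temperature=0.7
    "temperature": 0.7,
})
print(r.json())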