# LucianStorm: "Update app.py" (commit a542700, verified)
import os
import threading
import traceback

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Send the Hugging Face and torch caches to /tmp.
# NOTE(review): torch and transformers are imported before these assignments,
# so any setting a library reads only at import time will not see them;
# load_model() also passes cache_dir explicitly — confirm which one is relied on.
os.environ.update(
    TRANSFORMERS_CACHE="/tmp/transformers_cache",
    TORCH_HOME="/tmp/torch_cache",
)

app = FastAPI(title="DIANA - Diet And Nutrition Assistant")
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# check Starlette's CORSMiddleware behaviour for credentialed requests before
# any cookie/auth-based feature is added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# CPU-only inference. set_grad_enabled(False) turns autograd off for the
# importing thread (PyTorch grad mode is thread-local).
DEVICE = torch.device("cpu")
torch.set_num_threads(4)
torch.set_grad_enabled(False)

# Filled in by load_model(); /chat refuses requests until MODEL_LOADED is True.
model = tokenizer = None
MODEL_LOADED = False
def load_model():
    """Load the TinyLlama chat model and its tokenizer into the module globals.

    Returns:
        True on success, after setting ``model``, ``tokenizer`` and
        ``MODEL_LOADED = True``.
        False on any failure (best-effort: the error and traceback are printed,
        ``model``/``tokenizer`` are reset to None and ``MODEL_LOADED`` is set
        to False, so /chat keeps answering 503 instead of the app crashing).
    """
    global model, tokenizer, MODEL_LOADED
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    cache_dir = "/tmp/transformers_cache"
    try:
        print("Starting model load...")
        # Load into locals first so the globals never expose a half-loaded
        # pair (e.g. a tokenizer without its model) if a later step fails.
        new_tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            use_fast=True,
        )
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            device_map=None,
            cache_dir=cache_dir,
        ).to(DEVICE)
        new_model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        model = None
        tokenizer = None
        MODEL_LOADED = False
        return False
    model, tokenizer = new_model, new_tokenizer
    MODEL_LOADED = True
    return True
# Load the model once, at import time. The return value is ignored:
# load_model() prints any error itself and records the outcome in
# MODEL_LOADED, which /chat checks before answering (503 while False).
print("Initiating DIANA...")
load_model()
class Query(BaseModel):
    """Request body for POST /chat."""

    prompt: str  # the user's question or topic
    # NOTE(review): max_length and temperature are accepted but never read by
    # the /chat handler in this file, which uses fixed generation settings
    # (150 new tokens, temperature 0.7) equal to these defaults.
    max_length: int = 150
    temperature: float = 0.7
def get_structured_response(topic):
    """Return DIANA's canned, generic advice text about *topic*.

    Only the first line mentions *topic*; the rest is a fixed three-section
    checklist ending with DIANA's signature. /chat uses it as a fallback when
    generation fails or the model's reply looks incomplete.
    """
    # Written as escapes so the characters survive any re-encoding of this
    # source file (a mis-decoded copy had turned them into mojibake).
    bullet = "\u2022"  # U+2022 BULLET
    flexed_biceps = "\U0001F4AA"  # U+1F4AA FLEXED BICEPS
    return f"""Here's what you need to know about {topic}:
1. Start with the basics:
{bullet} Begin gradually
{bullet} Focus on proper form
{bullet} Stay consistent
2. Key points to remember:
{bullet} Set realistic goals
{bullet} Track your progress
{bullet} Listen to your body
3. Tips for success:
{bullet} Start today, not tomorrow
{bullet} Keep it simple
{bullet} Stay motivated
Need more specific advice about any of these points?
- DIANA {flexed_biceps}"""
# Whole words treated as a greeting (compared after lower-casing).
_GREETING_WORDS = frozenset({"hi", "hello", "hey"})


def is_greeting(text):
    """Return True if *text* contains one of the words "hi", "hello" or "hey".

    Matching is case-insensitive and on whole words only: the text is split on
    every non-alphanumeric character, so "Hi!" and "hey, DIANA" match, while
    words that merely contain a greeting ("this", "high", "they") do not.
    """
    words = "".join(ch if ch.isalnum() else " " for ch in text.lower()).split()
    return any(word in _GREETING_WORDS for word in words)
# DIANA's sign-off. The emoji is written as an escape so it survives any
# re-encoding of this source file; a mis-decoded copy would never match the
# emoji the model actually emits.
_DIANA_SIGNATURE = "- DIANA \U0001F4AA"

# Allows one tokenize/generate/decode at a time. While this work ran directly
# inside the async handler it was serialized by the event loop; it now runs in
# a worker thread, and the lock keeps that guarantee for the shared
# model and tokenizer.
_GENERATION_LOCK = threading.Lock()


def _build_prompt(topic):
    """Return the chat-formatted prompt asking for structured advice on *topic*.

    The assistant turn is already started ("Let me help you with that!"), so
    the model continues from there.
    """
    system_prompt = (
        f"You are DIANA, a fitness assistant. Give clear, complete advice about {topic}.\n"
        "Structure your response like this:\n"
        "1. Brief welcome and intro\n"
        "2. 3 main points with bullets\n"
        "3. Encouraging conclusion\n"
        f"4. Sign with '{_DIANA_SIGNATURE}'\n"
        "IMPORTANT: Never end mid-sentence. Always complete your thoughts."
    )
    return (
        f"<|system|>{system_prompt}</s>"
        f"<|user|>Give structured fitness advice about: {topic}</s>"
        "<|assistant|>Let me help you with that!\n\n"
    )


def _generate_reply(topic):
    """Run the model on the DIANA prompt for *topic*; return only the new text, stripped.

    Blocking CPU work: call it from a worker thread, not the event loop.
    """
    with _GENERATION_LOCK:
        inputs = tokenizer(
            _build_prompt(topic),
            return_tensors="pt",
            # NOTE(review): a long topic can push the prompt past 200 tokens,
            # and truncation then cuts part of it (the end, with the default
            # right-side truncation — confirm) before generation.
            truncation=True,
            max_length=200,
            padding=False,
        ).to(DEVICE)
        prompt_length = inputs["input_ids"].shape[-1]
        # Needed even though the module calls torch.set_grad_enabled(False):
        # grad mode is thread-local and this runs in a worker thread.
        with torch.inference_mode():
            outputs = model.generate(
                inputs["input_ids"],
                # pad_token_id is set to the EOS id below, so pass the mask
                # explicitly rather than letting generate() infer one.
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=150,
                min_new_tokens=100,  # push the model towards a full answer
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                eos_token_id=tokenizer.eos_token_id,
                num_beams=1,
                use_cache=True,
            )
        # Decode only the tokens generated after the prompt, so the system
        # prompt can never leak into the reply (even when it was truncated).
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _looks_complete(response):
    """Heuristically decide whether a generated reply is finished and substantial.

    Requires at least 4 non-empty "."-separated chunks, at least 50
    whitespace-separated words, and a last character of "!", ".", "?" or the
    signature emoji.
    """
    sentences = [chunk for chunk in response.split(".") if chunk.strip()]
    return (
        len(sentences) >= 4
        and len(response.split()) >= 50
        and response.endswith(("!", ".", "?", "\U0001F4AA"))
    )


@app.post("/chat")
async def chat(query: Query):
    """Answer a fitness question as DIANA.

    Returns ``{"response": text}``. Raises HTTP 503 while the model is not
    loaded. Greetings get a fixed reply. Otherwise the model's reply is used
    (signed if needed) when it looks complete; if it does not, or if anything
    fails, the canned ``get_structured_response`` text is returned.
    """
    if not MODEL_LOADED:
        raise HTTPException(status_code=503, detail="DIANA is initializing. Please try again.")
    try:
        if is_greeting(query.prompt):
            return {"response": f"Hi! I'm DIANA, your fitness assistant. How can I help you today?\n\n{_DIANA_SIGNATURE}"}
        # NOTE(review): query.max_length and query.temperature are accepted
        # but ignored; generation uses fixed settings equal to their defaults.
        # Inference is blocking CPU work: run it off the event loop so other
        # requests (e.g. GET /) are still served in the meantime.
        response = await run_in_threadpool(_generate_reply, query.prompt)
        if not _looks_complete(response):
            return {"response": get_structured_response(query.prompt)}
        if _DIANA_SIGNATURE not in response:
            response += f"\n\n{_DIANA_SIGNATURE}"
        return {"response": response}
    except Exception as e:
        # Best-effort endpoint: any failure degrades to the canned answer.
        print(f"Error: {e}")
        traceback.print_exc()
        return {"response": get_structured_response(query.prompt)}
@app.get("/")
def read_root():
    """Report service status and whether the model finished loading.

    The "status" text is the same whether or not loading succeeded; clients
    should look at "model_loaded" for readiness.
    """
    payload = {"status": "DIANA is ready!"}
    payload["model_loaded"] = MODEL_LOADED
    return payload
if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string. Run as a script,
    # this file is the __main__ module, so the string form makes uvicorn import
    # it again as "app": module setup and load_model() run a second time and
    # two copies of the model stay in memory. (It would also break outright if
    # the file were not named app.py.)
    uvicorn.run(app, host="0.0.0.0", port=7860)