"""DIANA (Diet And Nutrition Assistant): a FastAPI service wrapping TinyLlama chat.

Endpoints:
    POST /chat  -- answer a fitness/nutrition prompt (request body: ``Query``).
    GET  /      -- status probe reporting whether the model loaded.
"""

import logging
import os
import re
import threading

# Cache locations are exported before transformers is imported: transformers reads
# TRANSFORMERS_CACHE into a module-level constant at import time, so setting it
# after the import has no effect on its default cache path.
CACHE_DIR = '/tmp/transformers_cache'
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['TORCH_HOME'] = '/tmp/torch_cache'

import torch  # noqa: E402
import uvicorn  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DIANA - Diet And Nutrition Assistant")
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination; nothing in this file uses cookies or auth, so
# allow_credentials=False may be sufficient -- confirm with the frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

DEVICE = torch.device('cpu')
torch.set_num_threads(4)
# Grad mode is per-thread, so this only covers the importing thread; generation
# below additionally runs under torch.inference_mode() for the worker threads.
torch.set_grad_enabled(False)

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
SIGNATURE = "- DIANA 💪"
GREETING_WORDS = frozenset({'hi', 'hello', 'hey'})
# Messages longer than this are treated as real questions even if they say "hi".
MAX_GREETING_WORDS = 4
# The topic is embedded twice in the prompt; cap it so the prompt stays small.
MAX_TOPIC_CHARS = 500
MAX_NEW_TOKENS_LIMIT = 512
MIN_NEW_TOKENS = 100
# Primes the assistant turn; only tokens generated after it are returned.
ASSISTANT_PRIMER = "Let me help you with that!\n\n"

_WORD_RE = re.compile(r"[a-z']+")
# Generation is serialised (as it effectively was when it blocked the event
# loop): each call already uses all configured torch threads, and concurrent
# calls would only multiply memory use and CPU contention.
_GENERATION_LOCK = threading.Lock()

model = None
tokenizer = None
MODEL_LOADED = False


def load_model():
    """Load the TinyLlama tokenizer and model into the module globals.

    Returns:
        True on success. False if loading raised; the error is logged and
        ``MODEL_LOADED`` stays False, so ``/chat`` answers 503.
    """
    global model, tokenizer, MODEL_LOADED
    try:
        logger.info("Starting model load...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            use_fast=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            device_map=None,
            cache_dir=CACHE_DIR,
        ).to(DEVICE)
        model.eval()
        MODEL_LOADED = True
        return True
    except Exception:
        logger.exception("Error loading model")
        MODEL_LOADED = False
        return False


# The model is loaded at import time, before uvicorn starts serving this app.
logger.info("Initiating DIANA...")
load_model()


class Query(BaseModel):
    """Request body for ``POST /chat``."""

    prompt: str
    # Maximum number of new tokens; clamped to [1, MAX_NEW_TOKENS_LIMIT].
    max_length: int = 150
    # Sampling temperature; a value <= 0 selects greedy (non-sampled) decoding.
    temperature: float = 0.7


def get_structured_response(topic):
    """Return the canned, always-complete fallback answer mentioning *topic*."""
    return (
        f"Here's what you need to know about {topic}:\n\n"
        "1. Start with the basics:\n"
        "• Begin gradually\n"
        "• Focus on proper form\n"
        "• Stay consistent\n\n"
        "2. Key points to remember:\n"
        "• Set realistic goals\n"
        "• Track your progress\n"
        "• Listen to your body\n\n"
        "3. Tips for success:\n"
        "• Start today, not tomorrow\n"
        "• Keep it simple\n"
        "• Stay motivated\n\n"
        "Need more specific advice about any of these points?\n\n"
        f"{SIGNATURE}"
    )


def is_greeting(text, greetings=GREETING_WORDS, max_words=MAX_GREETING_WORDS):
    """Return True if *text* is a short message containing a greeting word.

    Words are matched whole (so "high", "this" or "they" do not count), and
    messages longer than *max_words* words are treated as real questions so
    that "Hi, how much protein do I need?" still gets an answer.
    """
    words = _WORD_RE.findall(text.lower())
    return 0 < len(words) <= max_words and any(w in greetings for w in words)


def _build_prompt(topic):
    """Render the chat prompt for *topic* in the model's own chat format."""
    system_prompt = (
        f"You are DIANA, a fitness assistant. Give clear, complete advice about {topic}.\n"
        "Structure your response like this:\n"
        "1. Brief welcome and intro\n"
        "2. 3 main points with bullets\n"
        "3. Encouraging conclusion\n"
        f"4. Sign with '{SIGNATURE}'\n"
        "IMPORTANT: Never end mid-sentence. Always complete your thoughts."
    )
    user_prompt = f"Give structured fitness advice about: {topic}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        # Fallback layout: role tag, newline, content, EOS token per turn.
        eos = tokenizer.eos_token
        prompt = (
            f"<|system|>\n{system_prompt}{eos}\n"
            f"<|user|>\n{user_prompt}{eos}\n"
            "<|assistant|>\n"
        )
    return prompt + ASSISTANT_PRIMER


def _generate_reply(query):
    """Run the model for *query* and return only the newly generated text."""
    topic = query.prompt.strip()[:MAX_TOPIC_CHARS]
    # No truncation here: the topic is already capped, and right-truncating the
    # rendered prompt would cut off the assistant tag the model continues from.
    inputs = tokenizer(_build_prompt(topic), return_tensors="pt").to(DEVICE)
    prompt_len = inputs["input_ids"].shape[-1]

    max_new_tokens = max(1, min(query.max_length, MAX_NEW_TOKENS_LIMIT))
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        # Must not exceed max_new_tokens.
        "min_new_tokens": min(MIN_NEW_TOKENS, max_new_tokens),
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 3,
        "num_beams": 1,
        "use_cache": True,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if query.temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=query.temperature, top_p=0.9)
    else:
        # Sampling requires a strictly positive temperature; use greedy instead.
        gen_kwargs["do_sample"] = False

    with _GENERATION_LOCK, torch.inference_mode():
        # **inputs passes the attention mask along with input_ids.
        outputs = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def _looks_complete(text):
    """Heuristic: at least 4 sentences, 50 words, and a terminal mark/sign-off."""
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    return (
        len(sentences) >= 4
        and len(text.split()) >= 50
        and text.endswith(('!', '.', '?', '💪'))
    )


# Declared with plain `def` so FastAPI runs it in its threadpool: model.generate
# is blocking and would otherwise stall the event loop (and every other request).
@app.post("/chat")
def chat(query: Query):
    """Answer *query*, falling back to a canned response on weak/failed output.

    Raises:
        HTTPException: 503 if the model did not load.
    """
    if not MODEL_LOADED:
        raise HTTPException(status_code=503, detail="DIANA is initializing. Please try again.")
    try:
        if is_greeting(query.prompt):
            return {"response": "Hi! I'm DIANA, your fitness assistant. How can I help you today?\n\n- DIANA 💪"}

        response = _generate_reply(query)
        # Possibly truncated or too thin: use the structured fallback instead.
        if not _looks_complete(response):
            return {"response": get_structured_response(query.prompt)}
        if SIGNATURE not in response:
            response += "\n\n" + SIGNATURE
        return {"response": response}
    except Exception:
        # Boundary handler: log the traceback, but still give the user an answer.
        logger.exception("Generation failed for /chat")
        return {"response": get_structured_response(query.prompt)}


@app.get("/")
def read_root():
    """Report service status and whether the model loaded."""
    return {"status": "DIANA is ready!", "model_loaded": MODEL_LOADED}


if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string: the string
    # makes uvicorn import this file a second time as module "app", re-running
    # load_model() and keeping two copies of the model in memory.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))