File size: 5,013 Bytes
c05461e
9f31314
c05461e
005eafc
c05461e
 
9f31314
c05461e
b91dc30
 
 
485b23d
c05461e
a542700
 
 
 
 
a4a53e5
b59f9c5
 
 
9f31314
b59f9c5
 
 
 
 
 
 
 
a542700
 
b59f9c5
 
 
 
b91dc30
b59f9c5
485b23d
 
a542700
b59f9c5
b91dc30
b59f9c5
 
 
 
 
 
9f31314
485b23d
b59f9c5
a4a53e5
 
 
a542700
b59f9c5
9f31314
a542700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485b23d
a542700
 
 
 
485b23d
a4a53e5
c05461e
9f31314
a542700
9f31314
c05461e
a542700
485b23d
a542700
485b23d
a542700
 
 
 
 
 
 
 
 
 
c05461e
e4aff5c
9f31314
e4aff5c
 
a542700
 
 
c05461e
a542700
c05461e
 
a542700
 
485b23d
c05461e
 
e1a117d
485b23d
 
a542700
 
 
 
c05461e
 
 
a542700
 
 
 
 
 
 
 
 
c05461e
a542700
485b23d
a542700
485b23d
e1a117d
b59f9c5
c05461e
a542700
 
005eafc
e1a117d
 
a542700
e1a117d
c05461e
e1a117d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import re

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Redirect HuggingFace / torch download caches to /tmp (writable on typical
# container hosts with a read-only app filesystem).
# NOTE(review): these are set *after* `from transformers import ...` above;
# explicit cache_dir arguments are also passed at load time, so confirm the
# env vars are still required.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['TORCH_HOME'] = '/tmp/torch_cache'

app = FastAPI(title="DIANA - Diet And Nutrition Assistant")

# Wide-open CORS: any origin/method/header may call this API.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

# CPU-only inference; cap thread count and disable autograd globally
# (generation also runs under torch.inference_mode in /chat).
DEVICE = torch.device('cpu')
torch.set_num_threads(4)
torch.set_grad_enabled(False)

# Populated by load_model(); MODEL_LOADED gates the /chat endpoint.
model = None
tokenizer = None
MODEL_LOADED = False

def load_model():
    """Download and initialize the TinyLlama chat model and its tokenizer.

    Populates the module-level ``model``, ``tokenizer`` and ``MODEL_LOADED``
    globals.  Returns ``True`` on success, ``False`` on any failure — the
    error is printed and swallowed on purpose so the API can still start
    and answer /chat with a 503 instead of crashing at import time.
    """
    global model, tokenizer, MODEL_LOADED

    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    cache = '/tmp/transformers_cache'
    try:
        print("Starting model load...")

        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            cache_dir=cache,
            use_fast=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float32,   # full precision on CPU
            low_cpu_mem_usage=True,
            device_map=None,             # single-device placement below
            cache_dir=cache,
        ).to(DEVICE)
        model.eval()                     # inference mode: disables dropout etc.
    except Exception as e:  # broad on purpose: startup must not crash the API
        print(f"Error loading model: {str(e)}")
        MODEL_LOADED = False
        return False

    MODEL_LOADED = True
    return True

# Module-level side effect: load the model at import/startup time so the
# first request does not pay the download/initialization cost.
print("Initiating DIANA...")
load_model()

class Query(BaseModel):
    """Request body for POST /chat."""
    prompt: str                # user's question or fitness topic
    max_length: int = 150      # requested generation budget, in tokens
    temperature: float = 0.7   # sampling temperature for generation

def get_structured_response(topic):
    """Return a fixed, well-formed advice template about *topic*.

    Serves as the fallback reply whenever model output looks truncated or
    generation fails, so the client always gets a complete, signed answer.
    """
    template_lines = [
        f"Here's what you need to know about {topic}:",
        "",
        "1. Start with the basics:",
        "   β€’ Begin gradually",
        "   β€’ Focus on proper form",
        "   β€’ Stay consistent",
        "",
        "2. Key points to remember:",
        "   β€’ Set realistic goals",
        "   β€’ Track your progress",
        "   β€’ Listen to your body",
        "",
        "3. Tips for success:",
        "   β€’ Start today, not tomorrow",
        "   β€’ Keep it simple",
        "   β€’ Stay motivated",
        "",
        "Need more specific advice about any of these points?",
        "",
        "- DIANA πŸ’ͺ",
    ]
    return "\n".join(template_lines)

def is_greeting(text):
    """Return True if *text* contains a standalone greeting word.

    Matches whole words only ('hi', 'hello', 'hey'), case-insensitively.
    The previous substring check (`'hi' in text.lower()`) wrongly fired on
    unrelated words such as 'something' (contains 'hi') or 'they'
    (contains 'hey'), routing real questions to the canned greeting.
    """
    words = re.findall(r"[a-z]+", text.lower())
    return any(w in ('hi', 'hello', 'hey') for w in words)

@app.post("/chat")
async def chat(query: Query):
    """Generate structured fitness advice for ``query.prompt``.

    Returns ``{"response": <text>}``.  Falls back to the canned template
    from ``get_structured_response`` when the model output looks truncated
    or generation raises; raises HTTP 503 while the model is still loading.

    Fix: ``query.max_length`` and ``query.temperature`` were declared on the
    request model but silently ignored (150 / 0.7 were hard-coded).  They are
    now honoured; the defaults are identical to the old hard-coded values, so
    existing callers see unchanged behaviour.  The attention mask is also
    forwarded to ``generate``, and the meaningless ``early_stopping=True``
    (it only applies to beam search, and ``num_beams=1``) was removed.
    """
    if not MODEL_LOADED:
        raise HTTPException(status_code=503, detail="DIANA is initializing. Please try again.")

    try:
        # Short-circuit plain greetings without running the model.
        if is_greeting(query.prompt):
            return {"response": "Hi! I'm DIANA, your fitness assistant. How can I help you today?\n\n- DIANA πŸ’ͺ"}

        # Optimized but complete prompt template
        system_prompt = f"""You are DIANA, a fitness assistant. Give clear, complete advice about {query.prompt}.
        Structure your response like this:
        1. Brief welcome and intro
        2. 3 main points with bullets
        3. Encouraging conclusion
        4. Sign with '- DIANA πŸ’ͺ'
        IMPORTANT: Never end mid-sentence. Always complete your thoughts."""

        # TinyLlama chat format; the seeded assistant opener is stripped from
        # the decoded output below.
        formatted_prompt = f"<|system|>{system_prompt}</s><|user|>Give structured fitness advice about: {query.prompt}</s><|assistant|>Let me help you with that!\n\n"

        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=200,
            padding=False
        ).to(DEVICE)

        with torch.inference_mode():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=query.max_length,             # was hard-coded 150
                min_new_tokens=min(100, query.max_length),   # floor never exceeds ceiling
                temperature=query.temperature,               # was hard-coded 0.7
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                eos_token_id=tokenizer.eos_token_id,  # proper ending
                num_beams=1,
                use_cache=True
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the assistant's continuation after the seeded opener.
        response = response.split("Let me help you with that!")[-1].strip()

        # Heuristic completeness check: enough sentences/words and a clean
        # terminal character, otherwise serve the structured fallback.
        sentences = [s.strip() for s in response.split('.') if s.strip()]
        words = response.split()
        if len(sentences) < 4 or len(words) < 50 or not response.endswith(('!', '.', '?', 'πŸ’ͺ')):
            return {"response": get_structured_response(query.prompt)}

        # Ensure proper signature
        if "- DIANA πŸ’ͺ" not in response:
            response += "\n\n- DIANA πŸ’ͺ"

        return {"response": response}

    except Exception as e:
        # Best-effort fallback: never surface a 500 for a generation error.
        print(f"Error: {str(e)}")
        return {"response": get_structured_response(query.prompt)}

@app.get("/")
def read_root():
    """Liveness endpoint: reports whether the model has finished loading."""
    payload = {"status": "DIANA is ready!"}
    payload["model_loaded"] = MODEL_LOADED
    return payload

# Dev entry point; the "app:app" import string assumes this file is app.py.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)