saiganesh2004 commited on
Commit
8408778
·
verified ·
1 Parent(s): f94a9e1

Update model_api.py

Browse files
Files changed (1) hide show
  1. model_api.py +75 -9
model_api.py CHANGED
@@ -2,25 +2,91 @@ from huggingface_hub import InferenceClient
2
  import os
3
 
4
  def query_model(prompt):
 
 
 
5
  try:
6
  HF_TOKEN = os.getenv("HF_TOKEN")
7
-
 
 
 
 
8
  client = InferenceClient(
9
  model="mistralai/Mistral-7B-Instruct-v0.2",
10
  token=HF_TOKEN
11
  )
12
-
 
 
 
 
 
 
 
 
 
 
13
  response = client.chat_completion(
14
  messages=[
15
- {"role": "system", "content": "You are a certified professional fitness trainer."},
16
  {"role": "user", "content": prompt}
17
  ],
18
- max_tokens=2500,#BEFORE 600
19
- temperature=0.7
 
20
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- return response.choices[0].message.content
23
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  except Exception as e:
25
- return f"Error: {str(e)}"
26
-
 
2
  import os
3
 
4
  def query_model(prompt):
5
+ """
6
+ Query the Mistral-7B model with the given prompt
7
+ """
8
  try:
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
+
11
+ if not HF_TOKEN:
12
+ return "Error: HF_TOKEN not found. Please set your Hugging Face token in environment variables."
13
+
14
+ # Initialize the client
15
  client = InferenceClient(
16
  model="mistralai/Mistral-7B-Instruct-v0.2",
17
  token=HF_TOKEN
18
  )
19
+
20
+ # Enhanced system prompt for better responses
21
+ system_prompt = """You are a certified professional fitness trainer with expertise in creating personalized workout plans.
22
+ Always provide complete, detailed workout plans with:
23
+ - Clear day-by-day structure
24
+ - Specific exercises with sets, reps, and rest periods
25
+ - Warm-up and cool-down recommendations
26
+ - Safety considerations based on user's profile
27
+ When asked for a 5-day plan, ensure ALL 5 days are included with clear day headers."""
28
+
29
+ # Make the API call
30
  response = client.chat_completion(
31
  messages=[
32
+ {"role": "system", "content": system_prompt},
33
  {"role": "user", "content": prompt}
34
  ],
35
+ max_tokens=2500, # Increased for complete 5-day plan
36
+ temperature=0.7,
37
+ top_p=0.95
38
  )
39
+
40
+ # Extract and return the response
41
+ workout_plan = response.choices[0].message.content
42
+
43
+ # Verify if the response contains all 5 days
44
+ days_found = sum([f"Day {i}" in workout_plan for i in range(1, 6)])
45
+
46
+ if days_found < 5:
47
+ # If incomplete, try one more time with more explicit instruction
48
+ retry_prompt = prompt + "\n\nIMPORTANT: The previous response was incomplete. Please ensure ALL 5 days (Day 1 through Day 5) are included in the plan. Each day should be clearly marked with 'Day X' header and include 4-6 exercises."
49
+
50
+ retry_response = client.chat_completion(
51
+ messages=[
52
+ {"role": "system", "content": system_prompt},
53
+ {"role": "user", "content": retry_prompt}
54
+ ],
55
+ max_tokens=2500,
56
+ temperature=0.7
57
+ )
58
+ workout_plan = retry_response.choices[0].message.content
59
+
60
+ return workout_plan
61
+
62
+ except Exception as e:
63
+ return f"Error generating workout plan: {str(e)}"
64
 
65
+ def test_api_connection():
66
+ """
67
+ Test function to verify API connection
68
+ """
69
+ try:
70
+ HF_TOKEN = os.getenv("HF_TOKEN")
71
+ if not HF_TOKEN:
72
+ return False, "HF_TOKEN not found"
73
+
74
+ client = InferenceClient(
75
+ model="mistralai/Mistral-7B-Instruct-v0.2",
76
+ token=HF_TOKEN
77
+ )
78
+
79
+ # Simple test prompt
80
+ response = client.chat_completion(
81
+ messages=[
82
+ {"role": "system", "content": "You are a helpful assistant."},
83
+ {"role": "user", "content": "Say 'API connection successful' if you can read this."}
84
+ ],
85
+ max_tokens=50,
86
+ temperature=0.1
87
+ )
88
+
89
+ return True, "API connection successful"
90
+
91
  except Exception as e:
92
+ return False, f"API connection failed: {str(e)}"