saiganesh2004 committed on
Commit
74ece72
·
verified ·
1 Parent(s): 2164ce5

Update model_api.py

Browse files
Files changed (1) hide show
  1. model_api.py +146 -68
model_api.py CHANGED
@@ -1,92 +1,170 @@
1
  from huggingface_hub import InferenceClient
2
  import os
 
 
 
 
 
3
 
4
  def query_model(prompt):
5
  """
6
- Query the Mistral-7B model with the given prompt
 
7
  """
8
  try:
9
- HF_TOKEN = os.getenv("HF_TOKEN")
10
-
11
- if not HF_TOKEN:
12
- return "Error: HF_TOKEN not found. Please set your Hugging Face token in environment variables."
13
-
14
- # Initialize the client
15
- client = InferenceClient(
16
- model="mistralai/Mistral-7B-Instruct-v0.2",
17
- token=HF_TOKEN
18
- )
19
 
20
- # Enhanced system prompt for better responses
21
- system_prompt = """You are a certified professional fitness trainer with expertise in creating personalized workout plans.
22
- Always provide complete, detailed workout plans with:
23
- - Clear day-by-day structure
24
- - Specific exercises with sets, reps, and rest periods
25
- - Warm-up and cool-down recommendations
26
- - Safety considerations based on user's profile
27
- When asked for a 5-day plan, ensure ALL 5 days are included with clear day headers."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Make the API call
30
- response = client.chat_completion(
31
  messages=[
32
  {"role": "system", "content": system_prompt},
33
- {"role": "user", "content": prompt}
34
  ],
35
- max_tokens=3000, # Increased for complete 5-day plan
36
- temperature=0.7,
37
- top_p=0.95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  )
39
 
40
- # Extract and return the response
41
- workout_plan = response.choices[0].message.content
42
-
43
- # Verify if the response contains all 5 days
44
- days_found = sum([f"Day {i}" in workout_plan for i in range(1, 6)])
45
-
46
- if days_found < 5:
47
- # If incomplete, try one more time with more explicit instruction
48
- retry_prompt = prompt + "\n\nIMPORTANT: The previous response was incomplete. Please ensure ALL 5 days (Day 1 through Day 5) are included in the plan. Each day should be clearly marked with 'Day X' header and include 4-6 exercises."
49
 
50
- retry_response = client.chat_completion(
51
- messages=[
52
- {"role": "system", "content": system_prompt},
53
- {"role": "user", "content": retry_prompt}
54
- ],
55
- max_tokens=2500,
56
- temperature=0.7
57
- )
58
- workout_plan = retry_response.choices[0].message.content
59
-
60
- return workout_plan
61
-
62
  except Exception as e:
63
- return f"Error generating workout plan: {str(e)}"
64
 
65
  def test_api_connection():
66
  """
67
  Test function to verify API connection
68
  """
69
  try:
70
- HF_TOKEN = os.getenv("HF_TOKEN")
71
- if not HF_TOKEN:
72
- return False, "HF_TOKEN not found"
73
-
74
- client = InferenceClient(
75
- model="mistralai/Mistral-7B-Instruct-v0.2",
76
- token=HF_TOKEN
77
- )
78
-
79
- # Simple test prompt
80
- response = client.chat_completion(
81
- messages=[
82
- {"role": "system", "content": "You are a helpful assistant."},
83
- {"role": "user", "content": "Say 'API connection successful' if you can read this."}
84
- ],
85
- max_tokens=50,
86
- temperature=0.1
87
- )
88
-
89
- return True, "API connection successful"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  except Exception as e:
92
- return False, f"API connection failed: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from huggingface_hub import InferenceClient
2
  import os
3
import requests
import json  # NOTE(review): appears unused in this file — kept, as only part of the repo is visible

# Hugging Face model id used by query_huggingface(); may be rebound at
# runtime via switch_model().
LLAMA_MODEL = "meta-llama/Llama-3.2-3B-Instruct"  # Default model
# Backend toggle read by query_model(); flipped at runtime via set_ollama_mode().
USE_LOCAL_OLLAMA = False  # Set to True if using local Ollama
8
 
9
def query_model(prompt):
    """Generate a workout plan for *prompt*.

    Dispatches to the local Ollama backend when the module flag
    USE_LOCAL_OLLAMA is set, otherwise to the Hugging Face Inference API.
    On any failure, returns an error string instead of raising.
    """
    backend = query_ollama if USE_LOCAL_OLLAMA else query_huggingface
    try:
        return backend(prompt)
    except Exception as e:
        return f"Error generating workout plan: {str(e)}"
22
+
23
def query_huggingface(prompt):
    """Query Llama via the Hugging Face Inference API.

    Returns the generated workout-plan text, or an error string when the
    HF_TOKEN environment variable is missing. Retries once if the first
    draft does not mention all five day headers.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        return "Error: HF_TOKEN not found. Please set your Hugging Face token in environment variables."

    # Client bound to the currently selected Llama checkpoint.
    hf_client = InferenceClient(
        model=LLAMA_MODEL,
        token=token
    )

    # System prompt that steers the model toward complete 5-day plans.
    system_prompt = """You are a certified professional fitness trainer with expertise in creating personalized workout plans.
    Always provide complete, detailed workout plans with:
    - Clear day-by-day structure
    - Specific exercises with sets, reps, and rest periods
    - Warm-up and cool-down recommendations
    - Safety considerations based on user's profile
    When asked for a 5-day plan, ensure ALL 5 days are included with clear day headers."""

    first_draft = hf_client.chat_completion(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        max_tokens=3000,
        temperature=0.7,
        top_p=0.95,
    )
    plan = first_draft.choices[0].message.content

    # One retry if any of "Day 1".."Day 5" is missing from the draft.
    if not all(f"Day {d}" in plan for d in range(1, 6)):
        amended_prompt = prompt + "\n\nIMPORTANT: The previous response was incomplete. Please ensure ALL 5 days (Day 1 through Day 5) are included in the plan. Each day should be clearly marked with 'Day X' header and include 4-6 exercises."

        second_draft = hf_client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": amended_prompt},
            ],
            max_tokens=3000,
            temperature=0.7,
        )
        plan = second_draft.choices[0].message.content

    return plan
79
+
80
def query_ollama(prompt):
    """Query Llama via a local Ollama server (no API key needed).

    Returns the generated workout-plan text, or an error string if the
    server is unreachable or responds with a non-200 status.
    """
    try:
        response = requests.post(
            "http://localhost:11434/api/generate",
            json={
                "model": "llama3.2:3b",  # or "llama3.2:1b" for lighter model
                "prompt": f"""You are a certified professional fitness trainer. Create a comprehensive 5-day workout plan.

{prompt}

Provide a complete, detailed 5-day workout plan with clear day headers, exercises, sets, reps, and rest periods.""",
                "stream": False,
                # BUG FIX: Ollama's /api/generate ignores unknown top-level
                # keys, so "max_tokens"/"temperature" at the top level were
                # silently dropped. Sampling parameters belong under
                # "options"; Ollama's token limit is called "num_predict".
                "options": {
                    "num_predict": 3000,
                    "temperature": 0.7,
                },
            },
            # Local generation can be slow, but without a timeout a wedged
            # server hangs this call forever.
            timeout=300,
        )

        if response.status_code == 200:
            return response.json()["response"]
        else:
            return f"Error: Ollama returned status code {response.status_code}"

    except requests.exceptions.ConnectionError:
        return "Error: Cannot connect to Ollama. Make sure Ollama is running locally (run 'ollama serve' in terminal)"
    except Exception as e:
        return f"Error with Ollama: {str(e)}"
109
 
110
def test_api_connection():
    """Verify connectivity to the active backend.

    Checks the local Ollama server when USE_LOCAL_OLLAMA is set, otherwise
    the Hugging Face Inference API.

    Returns:
        tuple: (success: bool, detail message: str)
    """
    try:
        if USE_LOCAL_OLLAMA:
            # Test Ollama connection with a tiny generation request.
            response = requests.post(
                "http://localhost:11434/api/generate",
                json={
                    "model": "llama3.2:3b",
                    "prompt": "Say 'API connection successful' if you can read this.",
                    "stream": False,
                    # BUG FIX: top-level "max_tokens" is not a valid Ollama
                    # field and was silently ignored; the limit must be
                    # "options.num_predict".
                    "options": {"num_predict": 50},
                },
                timeout=30,  # fail fast instead of hanging on a wedged server
            )
            if response.status_code == 200:
                return True, "Ollama connection successful"
            else:
                return False, f"Ollama connection failed: {response.status_code}"
        else:
            # Test Hugging Face connection.
            HF_TOKEN = os.getenv("HF_TOKEN")
            if not HF_TOKEN:
                return False, "HF_TOKEN not found"

            client = InferenceClient(
                model=LLAMA_MODEL,
                token=HF_TOKEN
            )

            response = client.chat_completion(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Say 'API connection successful' if you can read this."}
                ],
                max_tokens=50,
                temperature=0.1
            )

            return True, f"API connection successful (using {LLAMA_MODEL})"

    except Exception as e:
        return False, f"API connection failed: {str(e)}"
154
+
155
def switch_model(model_name):
    """Point subsequent Hugging Face queries at *model_name*.

    Rebinds the module-level LLAMA_MODEL and reports the change.
    """
    global LLAMA_MODEL
    LLAMA_MODEL = model_name
    return "Switched to {}".format(model_name)
162
+
163
def set_ollama_mode(use_ollama):
    """Toggle between the Hugging Face API and a local Ollama server.

    Rebinds the module-level USE_LOCAL_OLLAMA flag and reports the
    resulting mode.
    """
    global USE_LOCAL_OLLAMA
    USE_LOCAL_OLLAMA = use_ollama
    return "Switched to {} mode".format(
        "local Ollama" if use_ollama else "Hugging Face API"
    )