Update app.py
Browse files
app.py
CHANGED
|
@@ -17,8 +17,12 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
| 17 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
| 18 |
IMAGE_GENERATION_SPACE_NAME = "stabilityai/stable-diffusion-3.5-large-turbo"
|
| 19 |
|
| 20 |
-
# Initialize Groq client
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# LLM Models (free options)
|
| 24 |
LLM_MODELS = {
|
|
@@ -41,19 +45,23 @@ def generate_tutor_output(subject, difficulty, student_input, model):
|
|
| 41 |
Format your response as a JSON object with keys: "lesson", "question", "feedback"
|
| 42 |
"""
|
| 43 |
|
| 44 |
-
if model.startswith("mixtral"): # Groq model
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
else: # Hugging Face models
|
| 58 |
try:
|
| 59 |
client = Client("https://api-inference.huggingface.co/models/" + model, hf_token=HF_TOKEN)
|
|
@@ -61,16 +69,18 @@ def generate_tutor_output(subject, difficulty, student_input, model):
|
|
| 61 |
return json.loads(response)
|
| 62 |
except:
|
| 63 |
st.warning(f"HF model {model} failed, falling back to Mixtral.")
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
|
| 66 |
def generate_image(prompt, path='temp_image.png'):
|
| 67 |
try:
|
| 68 |
client = Client(IMAGE_GENERATION_SPACE_NAME, hf_token=HF_TOKEN)
|
| 69 |
result = client.predict(
|
| 70 |
prompt=prompt,
|
| 71 |
-
width=512,
|
| 72 |
height=512,
|
| 73 |
-
api_name="/predict"
|
| 74 |
)
|
| 75 |
image = Image.open(result)
|
| 76 |
image.save(path)
|
|
|
|
| 17 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
| 18 |
IMAGE_GENERATION_SPACE_NAME = "stabilityai/stable-diffusion-3.5-large-turbo"
|
| 19 |
|
| 20 |
+
# Initialize Groq client with minimal parameters
|
| 21 |
+
try:
|
| 22 |
+
groq_client = Groq(api_key=GROQ_API_KEY)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
st.error(f"Failed to initialize Groq client: {e}")
|
| 25 |
+
groq_client = None
|
| 26 |
|
| 27 |
# LLM Models (free options)
|
| 28 |
LLM_MODELS = {
|
|
|
|
| 45 |
Format your response as a JSON object with keys: "lesson", "question", "feedback"
|
| 46 |
"""
|
| 47 |
|
| 48 |
+
if model.startswith("mixtral") and groq_client: # Groq model
|
| 49 |
+
try:
|
| 50 |
+
completion = groq_client.chat.completions.create(
|
| 51 |
+
messages=[{
|
| 52 |
+
"role": "system",
|
| 53 |
+
"content": f"You are the world's best AI tutor for {subject}, renowned for clear, engaging explanations."
|
| 54 |
+
}, {
|
| 55 |
+
"role": "user",
|
| 56 |
+
"content": prompt
|
| 57 |
+
}],
|
| 58 |
+
model=model,
|
| 59 |
+
max_tokens=1000
|
| 60 |
+
)
|
| 61 |
+
return json.loads(completion.choices[0].message.content)
|
| 62 |
+
except Exception as e:
|
| 63 |
+
st.error(f"Groq error: {e}")
|
| 64 |
+
return {"lesson": "Error generating lesson", "question": "N/A", "feedback": "N/A"}
|
| 65 |
else: # Hugging Face models
|
| 66 |
try:
|
| 67 |
client = Client("https://api-inference.huggingface.co/models/" + model, hf_token=HF_TOKEN)
|
|
|
|
| 69 |
return json.loads(response)
|
| 70 |
except:
|
| 71 |
st.warning(f"HF model {model} failed, falling back to Mixtral.")
|
| 72 |
+
if groq_client:
|
| 73 |
+
return generate_tutor_output(subject, difficulty, student_input, "mixtral-8x7b-32768")
|
| 74 |
+
return {"lesson": "Error generating lesson", "question": "N/A", "feedback": "N/A"}
|
| 75 |
|
| 76 |
def generate_image(prompt, path='temp_image.png'):
|
| 77 |
try:
|
| 78 |
client = Client(IMAGE_GENERATION_SPACE_NAME, hf_token=HF_TOKEN)
|
| 79 |
result = client.predict(
|
| 80 |
prompt=prompt,
|
| 81 |
+
width=512,
|
| 82 |
height=512,
|
| 83 |
+
api_name="/predict"
|
| 84 |
)
|
| 85 |
image = Image.open(result)
|
| 86 |
image.save(path)
|