Update app.py
app.py
CHANGED
@@ -15,20 +15,21 @@ load_dotenv()
 # Constants
 HF_TOKEN = os.getenv("HF_TOKEN")
 GROQ_API_KEY = os.getenv("GROQ_API_KEY")
-
+# Switching to HF Inference API for stability
+IMAGE_GENERATION_API = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-2-1"
 
 # Initialize Groq client
 try:
     groq_client = Groq(api_key=GROQ_API_KEY)
 except Exception as e:
     st.error(f"Failed to initialize Groq client: {e}")
     groq_client = None
 
 # LLM Models
 LLM_MODELS = {
     "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
     "Mistral 7B (HF)": "mistralai/Mistral-7B-Instruct-v0.1",
     "LLaMA 13B (HF)": "meta-llama/Llama-2-13b-hf"
 }
 
 # Utility Functions
@@ -45,7 +46,7 @@ def generate_tutor_output(subject, difficulty, student_input, model):
     Format your response as a JSON object with keys: "lesson", "question", "feedback"
     """
 
     if model.startswith("mixtral") and groq_client:
         try:
             completion = groq_client.chat.completions.create(
                 messages=[{
@@ -61,8 +62,8 @@ def generate_tutor_output(subject, difficulty, student_input, model):
             return json.loads(completion.choices[0].message.content)
         except Exception as e:
             st.error(f"Groq error: {e}")
-            return {"lesson": "
+            return {"lesson": "Sorry, unable to generate lesson due to API issue.", "question": "N/A", "feedback": "Please try again or check your input."}
     else:
         try:
             client = Client("https://api-inference.huggingface.co/models/" + model, hf_token=HF_TOKEN)
             response = client.predict(prompt, api_name="/generate")
@@ -71,18 +72,16 @@ def generate_tutor_output(subject, difficulty, student_input, model):
             st.warning(f"HF model {model} failed, falling back to Mixtral.")
             if groq_client:
                 return generate_tutor_output(subject, difficulty, student_input, "mixtral-8x7b-32768")
-            return {"lesson": "
+            return {"lesson": "Sorry, unable to generate lesson.", "question": "N/A", "feedback": "N/A"}
 
 def generate_image(prompt, path='temp_image.png'):
     try:
-        client = Client(
-        result = client.predict(
-
-
-
-
-        )
-        image = Image.open(result)
+        client = Client(IMAGE_GENERATION_API, hf_token=HF_TOKEN)
+        result = client.predict(prompt, api_name="/predict")
+        if isinstance(result, str):  # Handle file path or binary data
+            image = Image.open(result)
+        else:
+            image = Image.open(io.BytesIO(result))  # binary payload; assumes io is imported at the top of app.py
         image.save(path)
         return path
     except Exception as e:
@@ -91,6 +90,9 @@ def generate_image(prompt, path='temp_image.png'):
 
 def generate_video(images, audio_text, language, speaker, path='temp_video.mp4'):
     try:
+        if not images or all(img is None for img in images):
+            st.error("No valid images to create video.")
+            return None
         audio_client = Client("habib926653/Multilingual-TTS")
         audio_result = audio_client.predict(
             text=audio_text,
@@ -106,8 +108,11 @@ def generate_video(images, audio_text, language, speaker, path='temp_video.mp4'):
             f.write(audio_bytes)
 
         audio_clip = mp.AudioFileClip(audio_path)
-        duration_per_image = audio_clip.duration / len(images)
+        duration_per_image = audio_clip.duration / len([img for img in images if img])
         image_clips = [mp.ImageClip(img).set_duration(duration_per_image) for img in images if img]
+        if not image_clips:
+            st.error("No image clips generated.")
+            return None
         video = mp.concatenate_videoclips(image_clips, method="compose").set_audio(audio_clip)
         video.write_videofile(path, fps=24, codec='libx264')
         return path
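
Note on the new constant: stabilityai/stable-diffusion-2-1 is served by the plain HF Inference API rather than a Gradio Space, so the endpoint can also be queried directly over HTTP. A minimal sketch with requests, assuming the standard Inference API contract (POST a JSON body with an "inputs" prompt, raw image bytes come back); the helper name generate_image_http and the timeout value are illustrative, not part of this commit:

import io
import os

import requests
from PIL import Image

API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-2-1"
HEADERS = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

def generate_image_http(prompt, path="temp_image.png"):
    # The Inference API returns the generated image as raw bytes in the response body.
    response = requests.post(API_URL, headers=HEADERS, json={"inputs": prompt}, timeout=120)
    response.raise_for_status()
    Image.open(io.BytesIO(response.content)).save(path)
    return path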
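
The two fallback returns added to generate_tutor_output guarantee that callers always receive a dict with the keys "lesson", "question", and "feedback", even when an API call fails. The same guarantee could be centralized; a sketch of a hypothetical parse_tutor_json helper (not part of this commit) wrapping the json.loads call:

import json

FALLBACK = {
    "lesson": "Sorry, unable to generate lesson.",
    "question": "N/A",
    "feedback": "N/A",
}

def parse_tutor_json(raw):
    # Models sometimes emit malformed or non-object JSON; fall back instead of raising.
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return dict(FALLBACK)
    if not isinstance(data, dict):
        return dict(FALLBACK)
    # Backfill any missing keys so callers can index unconditionally.
    return {key: data.get(key, default) for key, default in FALLBACK.items()}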
"""
|
| 48 |
|
| 49 |
+
if model.startswith("mixtral") and groq_client:
|
| 50 |
try:
|
| 51 |
completion = groq_client.chat.completions.create(
|
| 52 |
messages=[{
|
|
|
|
| 62 |
return json.loads(completion.choices[0].message.content)
|
| 63 |
except Exception as e:
|
| 64 |
st.error(f"Groq error: {e}")
|
| 65 |
+
return {"lesson": "Sorry, unable to generate lesson due to API issue.", "question": "N/A", "feedback": "Please try again or check your input."}
|
| 66 |
+
else:
|
| 67 |
try:
|
| 68 |
client = Client("https://api-inference.huggingface.co/models/" + model, hf_token=HF_TOKEN)
|
| 69 |
response = client.predict(prompt, api_name="/generate")
|
|
|
|
| 72 |
st.warning(f"HF model {model} failed, falling back to Mixtral.")
|
| 73 |
if groq_client:
|
| 74 |
return generate_tutor_output(subject, difficulty, student_input, "mixtral-8x7b-32768")
|
| 75 |
+
return {"lesson": "Sorry, unable to generate lesson.", "question": "N/A", "feedback": "N/A"}
|
| 76 |
|
| 77 |
def generate_image(prompt, path='temp_image.png'):
|
| 78 |
try:
|
| 79 |
+
client = Client(IMAGE_GENERATION_API, hf_token=HF_TOKEN)
|
| 80 |
+
result = client.predict(prompt, api_name="/predict")
|
| 81 |
+
if isinstance(result, str): # Handle file path or binary data
|
| 82 |
+
image = Image.open(result)
|
| 83 |
+
else:
|
| 84 |
+
image = Image.open(result)
|
|
|
|
|
|
|
| 85 |
image.save(path)
|
| 86 |
return path
|
| 87 |
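
The Groq branch is cut off in the diff (the messages list ends at messages=[{). For context, the groq package exposes an OpenAI-compatible chat API; a sketch of the usual call shape, where the system/user message split and the temperature value are assumptions rather than content recovered from the commit:

import os

from groq import Groq

groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

def ask_mixtral(prompt):
    # Mirrors the completion.choices[0].message.content access visible in the diff.
    completion = groq_client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=[
            {"role": "system", "content": "You are a tutor. Reply with a single JSON object."},  # assumed
            {"role": "user", "content": prompt},
        ],
        temperature=0.7,  # assumed; not visible in the diff
    )
    return completion.choices[0].message.content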
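
With the new guards, the slide timing in generate_video is well defined: the narration length is divided across valid images only, so None entries no longer dilute the per-slide duration. A worked check with assumed numbers (12 s of audio, one None among four images):

# Stand-ins for audio_clip.duration and the images argument of generate_video.
images = ["a.png", None, "b.png", "c.png"]
audio_duration = 12.0  # seconds

valid = [img for img in images if img]
assert valid, "No valid images to create video."
duration_per_image = audio_duration / len(valid)  # 12 / 3 = 4.0 s per slide
assert duration_per_image == 4.0

Under the old len(images) denominator the same inputs gave 3.0 s per slide across only three actual clips, producing 9 s of video against 12 s of audio.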