luis-poe committed on
Commit
7e5f39d
·
verified ·
1 Parent(s): 7bbf670

Update compliment.py

Browse files
Files changed (1) hide show
  1. compliment.py +79 -153
compliment.py CHANGED
@@ -1,153 +1,79 @@
1
- import asyncio
2
- from groq import Groq
3
- import edge_tts
4
- import tempfile
5
- import os
6
-
7
- # Create a Groq client once at the module level to reuse across function calls
8
- client = Groq()
9
async def text_to_speech(text, language):
    """
    Synthesize speech for a text with Edge TTS and return the MP3 data.

    Args:
    - text (str): The text to speak.
    - language (str): Language code. 'de' (case-insensitive) selects the German
      voice; any other value selects the English voice.

    Returns:
    - bytes: The content of the MP3 file written by Edge TTS.
    """
    # Map language to Edge TTS voice
    if language.lower() == 'de':
        voice = 'de-DE-KatjaNeural'  # German female voice
    else:
        voice = 'en-US-AriaNeural'  # English female voice

    rate = "+0%"
    pitch = "+0Hz"

    communicate = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)

    # Only the file name is needed here; the handle is closed before Edge TTS
    # writes to the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name

    try:
        await communicate.save(tmp_path)
        with open(tmp_path, 'rb') as audio_file:
            return audio_file.read()
    finally:
        # delete=False means nothing else removes the file, so clean it up
        # whether synthesis and reading succeeded or raised.
        os.remove(tmp_path)
29
async def generate_compliment(base64_image, compliment_prompt, model="llama-3.2-90b-vision-preview", max_tokens=300, temperature=0.5):
    """
    Generate a compliment for an image by sending it, together with a prompt,
    to the chat completion API of the module-level Groq client.

    Args:
    - base64_image (str): The base64 encoded image (sent as a JPEG data URL).
    - compliment_prompt (str): The prompt to use for the chat completion.
    - model (str, optional): The model to use for the chat completion.
    - max_tokens (int, optional): The maximum number of tokens to generate.
    - temperature (float, optional): The sampling temperature.

    Returns:
    - str: The generated compliment (empty string if the model returned no text).
    """

    # Prepare the messages payload: one user turn holding the prompt and the image
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": compliment_prompt},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}",
                    },
                },
            ],
        }
    ]

    # The module-level client is the synchronous Groq client, which has no
    # `create_async` method. Run its blocking `create` call in a worker thread
    # so the event loop is not blocked. The whole text is returned at once, so
    # streaming would bring no benefit here.
    chat_completion = await asyncio.to_thread(
        client.chat.completions.create,
        max_tokens=max_tokens,
        temperature=temperature,
        messages=messages,
        model=model,
    )

    # Keep the "always a str" contract even if the message carries no content
    return chat_completion.choices[0].message.content or ""
82
-
83
async def generate_compliment_and_audio(base64_image, compliment_prompt, model="llama-3.2-90b-vision-preview", max_tokens=300, temperature=0.5, tts_language='en'):
    """
    Generate a compliment for an image and synthesize audio for the complete text.

    The compliment is fetched in full before text-to-speech starts, so the
    returned audio always corresponds to the returned text.

    Args:
    - base64_image (str): The base64 encoded image (sent as a JPEG data URL).
    - compliment_prompt (str): The prompt for generating the compliment.
    - model (str, optional): The model to use for the chat completion.
    - max_tokens (int, optional): The maximum number of tokens to generate.
    - temperature (float, optional): The sampling temperature.
    - tts_language (str, optional): The language code for TTS.

    Returns:
    - Tuple[str, bytes]: The generated compliment and the audio data as
      returned by text_to_speech.
    """

    # Prepare the messages payload: one user turn holding the prompt and the image
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": compliment_prompt},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}",
                    },
                },
            ],
        }
    ]

    # The module-level client is the synchronous Groq client, which has no
    # `create_async` method. Run its blocking `create` call in a worker thread
    # so the event loop is not blocked.
    chat_completion = await asyncio.to_thread(
        client.chat.completions.create,
        max_tokens=max_tokens,
        temperature=temperature,
        messages=messages,
        model=model,
    )
    full_compliment = chat_completion.choices[0].message.content or ""

    # Synthesize only once the whole text is known. Starting TTS on a partial
    # prefix would produce audio that stops before the end of the compliment.
    audio_data = await text_to_speech(full_compliment, tts_language)

    return full_compliment, audio_data
 
1
+ import asyncio
2
+ from groq import Groq
3
+ import edge_tts
4
+ import tempfile
5
+ import os
6
+
7
+ # Create a Groq client once at the module level to reuse across function calls
8
+ client = Groq()
9
+
10
async def text_to_speech(text, language):
    """
    Synthesize speech for a text with Edge TTS and return the path of the MP3 file.

    On success the file is deliberately left on disk, because the caller
    (Gradio, according to the original comment) reads it by path afterwards;
    the caller is then responsible for removing it. If synthesis fails, the
    file is removed before the error propagates.

    Args:
    - text (str): The text to speak.
    - language (str): Language code. 'de' and regional forms such as 'de-DE'
      or 'de_AT' (case-insensitive) select the German voice; any other value,
      including None or an empty string, selects the English voice.

    Returns:
    - str: The path to the generated .mp3 file.
    """
    # Reduce tags like "de-DE" / "de_AT" to their primary subtag before mapping
    primary_tag = (language or "").strip().lower().replace("_", "-").split("-")[0]

    # Map language to Edge TTS voice
    if primary_tag == 'de':
        voice = 'de-DE-KatjaNeural'  # German female voice
    else:
        voice = 'en-US-AriaNeural'  # English female voice

    rate = "+10%"
    pitch = "+0Hz"

    # Build the request before creating the temp file, so that a rejected
    # argument cannot leave an orphaned file behind.
    communicate = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)

    # Only the file name is needed here; the handle is closed before Edge TTS
    # writes to the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name

    try:
        await communicate.save(tmp_path)
    except BaseException:
        # Failure or cancellation: no caller will ever receive the path, so
        # remove the file here. Cleanup is best-effort and must not mask the
        # original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    # Do not delete the file on success; the caller still needs to access it
    return tmp_path
28
+
29
async def generate_compliment_and_audio(base64_image, compliment_prompt, model="llama-3.2-90b-vision-preview", max_tokens=300, temperature=0.5, tts_language='en'):
    """
    Generate a compliment for an image, then synthesize audio for it.

    The two steps run one after the other: the chat completion is fetched in
    a worker thread, and text-to-speech starts once the full text is available.

    Args:
    - base64_image (str): The base64 encoded image (sent as a JPEG data URL).
    - compliment_prompt (str): The prompt for generating the compliment.
    - model (str, optional): The model to use for the chat completion.
    - max_tokens (int, optional): The maximum number of tokens to generate.
    - temperature (float, optional): The sampling temperature.
    - tts_language (str, optional): The language code for TTS.

    Returns:
    - Tuple[str, str]: The generated compliment and the audio file path.

    Raises:
    - RuntimeError: If the completion contains no choices or no text.
    """

    # Prepare the messages payload: one user turn holding the prompt and the image
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": compliment_prompt},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}",
                    },
                },
            ],
        }
    ]

    # The Groq client used here is synchronous, so run its blocking call in a
    # worker thread to keep the event loop free.
    chat_completion = await asyncio.to_thread(
        client.chat.completions.create,
        max_tokens=max_tokens,
        temperature=temperature,
        messages=messages,
        model=model,
    )

    # Extract the compliment, failing with a clear message instead of handing
    # missing text to the TTS step.
    choices = chat_completion.choices
    compliment = choices[0].message.content if choices else None
    if not compliment:
        raise RuntimeError("The chat completion returned no compliment text")

    # Synthesize the audio for the complete compliment
    audio_file_path = await text_to_speech(compliment, tts_language)

    return compliment, audio_file_path