Update app.py
Browse files
app.py
CHANGED
|
@@ -10,43 +10,6 @@ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIte
|
|
| 10 |
from transformers.image_utils import load_image
|
| 11 |
import time
|
| 12 |
|
| 13 |
-
# =============================================================================
|
| 14 |
-
# New imports and helper classes for image generation
|
| 15 |
-
# =============================================================================
|
| 16 |
-
try:
|
| 17 |
-
# We use Hugging Face’s InferenceClient as a generic image-generation API client.
|
| 18 |
-
from huggingface_hub import InferenceClient as HFInferenceClient
|
| 19 |
-
except ImportError:
|
| 20 |
-
HFInferenceClient = None
|
| 21 |
-
|
| 22 |
-
# A simple wrapper client for our primary image-generation space.
|
| 23 |
-
class Client:
|
| 24 |
-
def __init__(self, repo_id):
|
| 25 |
-
self.repo_id = repo_id
|
| 26 |
-
if HFInferenceClient is not None:
|
| 27 |
-
self.client = HFInferenceClient(repo_id)
|
| 28 |
-
else:
|
| 29 |
-
self.client = None
|
| 30 |
-
|
| 31 |
-
def predict(self, task, arg2, prompt, api_name):
|
| 32 |
-
if self.client is not None:
|
| 33 |
-
# Here we assume that calling the client with the prompt returns an image.
|
| 34 |
-
# (Depending on your API, you might need to adjust parameters.)
|
| 35 |
-
return self.client(prompt)
|
| 36 |
-
else:
|
| 37 |
-
raise Exception("HFInferenceClient not available")
|
| 38 |
-
|
| 39 |
-
def image_gen(prompt):
|
| 40 |
-
"""
|
| 41 |
-
Uses the STABLE-HAMSTER space to generate an image based on the prompt.
|
| 42 |
-
"""
|
| 43 |
-
client = Client("prithivMLmods/STABLE-HAMSTER")
|
| 44 |
-
return client.predict("Image Generation", None, prompt, api_name="/stable_hamster")
|
| 45 |
-
|
| 46 |
-
# =============================================================================
|
| 47 |
-
# Original Code (with modifications below)
|
| 48 |
-
# =============================================================================
|
| 49 |
-
|
| 50 |
DESCRIPTION = """
|
| 51 |
# QwQ Edge 💬
|
| 52 |
"""
|
|
@@ -123,46 +86,13 @@ def generate(
|
|
| 123 |
repetition_penalty: float = 1.2,
|
| 124 |
):
|
| 125 |
"""
|
| 126 |
-
Generates chatbot responses with support for multimodal input
|
| 127 |
If the query starts with an @tts command (e.g. "@tts1"), previous chat history is cleared.
|
| 128 |
-
If the query starts with an @image command, the image generation branch is used.
|
| 129 |
"""
|
| 130 |
text = input_dict["text"]
|
| 131 |
files = input_dict.get("files", [])
|
| 132 |
|
| 133 |
-
#
|
| 134 |
-
# NEW: Check for image generation command (@image)
|
| 135 |
-
# -------------------------------------------------------------------------
|
| 136 |
-
image_prefix = "@image"
|
| 137 |
-
if text.strip().lower().startswith(image_prefix):
|
| 138 |
-
# Remove the prefix and any extra whitespace
|
| 139 |
-
query = text[len(image_prefix):].strip()
|
| 140 |
-
yield "Generating Image, Please wait 10 sec..."
|
| 141 |
-
try:
|
| 142 |
-
image = image_gen(query)
|
| 143 |
-
# If the API returns a tuple (as in the snippet) use the second element;
|
| 144 |
-
# otherwise assume it returns an image directly.
|
| 145 |
-
if isinstance(image, (list, tuple)) and len(image) > 1:
|
| 146 |
-
yield gr.Image(image[1])
|
| 147 |
-
else:
|
| 148 |
-
yield gr.Image(image)
|
| 149 |
-
except Exception as e:
|
| 150 |
-
yield "Error in primary image generation, trying fallback..."
|
| 151 |
-
try:
|
| 152 |
-
# Use the fallback image generation client.
|
| 153 |
-
if HFInferenceClient is not None:
|
| 154 |
-
client_flux = HFInferenceClient("black-forest-labs/FLUX.1-schnell")
|
| 155 |
-
image = client_flux.text_to_image(query)
|
| 156 |
-
yield gr.Image(image)
|
| 157 |
-
else:
|
| 158 |
-
yield "Fallback client not available."
|
| 159 |
-
except Exception as fallback_error:
|
| 160 |
-
yield f"Error in image generation: {str(fallback_error)}"
|
| 161 |
-
return # End execution after processing the image-generation request.
|
| 162 |
-
|
| 163 |
-
# -------------------------------------------------------------------------
|
| 164 |
-
# Continue with the original processing (image files, TTS, or text conversation)
|
| 165 |
-
# -------------------------------------------------------------------------
|
| 166 |
if len(files) > 1:
|
| 167 |
images = [load_image(image) for image in files]
|
| 168 |
elif len(files) == 1:
|
|
@@ -173,7 +103,7 @@ def generate(
|
|
| 173 |
tts_prefix = "@tts"
|
| 174 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
| 175 |
voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
|
| 176 |
-
|
| 177 |
if is_tts and voice_index:
|
| 178 |
voice = TTS_VOICES[voice_index - 1]
|
| 179 |
text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
|
|
@@ -258,7 +188,6 @@ demo = gr.ChatInterface(
|
|
| 258 |
["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
|
| 259 |
["Write a Python function to check if a number is prime."],
|
| 260 |
["@tts2 What causes rainbows to form?"],
|
| 261 |
-
["@image A beautiful sunset over a mountain range"],
|
| 262 |
],
|
| 263 |
cache_examples=False,
|
| 264 |
type="messages",
|
|
|
|
| 10 |
from transformers.image_utils import load_image
|
| 11 |
import time
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
DESCRIPTION = """
|
| 14 |
# QwQ Edge 💬
|
| 15 |
"""
|
|
|
|
| 86 |
repetition_penalty: float = 1.2,
|
| 87 |
):
|
| 88 |
"""
|
| 89 |
+
Generates chatbot responses with support for multimodal input and TTS.
|
| 90 |
If the query starts with an @tts command (e.g. "@tts1"), previous chat history is cleared.
|
|
|
|
| 91 |
"""
|
| 92 |
text = input_dict["text"]
|
| 93 |
files = input_dict.get("files", [])
|
| 94 |
|
| 95 |
+
# Process image files if provided
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
if len(files) > 1:
|
| 97 |
images = [load_image(image) for image in files]
|
| 98 |
elif len(files) == 1:
|
|
|
|
| 103 |
tts_prefix = "@tts"
|
| 104 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
| 105 |
voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
|
| 106 |
+
|
| 107 |
if is_tts and voice_index:
|
| 108 |
voice = TTS_VOICES[voice_index - 1]
|
| 109 |
text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
|
|
|
|
| 188 |
["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
|
| 189 |
["Write a Python function to check if a number is prime."],
|
| 190 |
["@tts2 What causes rainbows to form?"],
|
|
|
|
| 191 |
],
|
| 192 |
cache_examples=False,
|
| 193 |
type="messages",
|