Spaces:
Sleeping
Sleeping
| import os | |
| import hmac | |
| from typing import Iterable, Generator, Optional, Dict, Any | |
| import gradio as gr | |
| # --- Optional: install SDKs in your Space (add to requirements.txt) --- | |
| # openai>=1.40.0 | |
| # google-genai>=0.3.0 | |
| # OpenAI-compatible SDK (used for OpenAI and DeepSeek) | |
| from openai import OpenAI | |
| # Gemini SDK | |
| try: | |
| from google import genai | |
| from google.genai import types as genai_types | |
| except ImportError: | |
| genai = None | |
| genai_types = None | |
| # -------- Helpers: secret/backdoor resolution (optional pattern you already like) -------- | |
def _timing_safe_eq(a: str, b: str) -> bool:
    """Return True when *a* and *b* are equal, using a constant-time comparison.

    Both strings are encoded to UTF-8 bytes first: ``hmac.compare_digest``
    raises ``TypeError`` when given ``str`` arguments containing non-ASCII
    characters, whereas bytes of any content are accepted.
    """
    return hmac.compare_digest(a.encode("utf-8"), b.encode("utf-8"))
def _resolve_key(user_value: Optional[str], secret_gate_name: str, secret_payload_name: str) -> Optional[str]:
    """Resolve the API key to use from what the user typed.

    If ``user_value`` is an 8-digit code equal to the environment variable
    named ``secret_gate_name``, return the value of the environment variable
    named ``secret_payload_name`` (``None`` when that variable is unset or
    blank). Otherwise return ``user_value`` stripped of surrounding
    whitespace, or ``None`` when it is empty or ``None``.
    """
    candidate = (user_value or "").strip()
    passcode = (os.getenv(secret_gate_name) or "").strip()

    # ``str.isdigit`` is also True for non-ASCII digits (e.g. Arabic-Indic or
    # superscript digits). Require ASCII so such input is treated as an
    # ordinary key instead of being fed to the passcode comparison.
    looks_like_passcode = (
        candidate.isascii() and candidate.isdigit() and len(candidate) == 8
    )
    if looks_like_passcode and passcode and _timing_safe_eq(candidate, passcode):
        return (os.getenv(secret_payload_name) or "").strip() or None
    return candidate or None
| # -------- Providers -------- | |
# Model identifiers offered in the "Model" dropdown for each provider.
OPENAI_MODELS = [
    "gpt-4o",
    "o4-mini",  # NOTE(review): o-series model id, not the same as "gpt-4o-mini" — confirm the intended model against OpenAI docs
    "gpt-3.5-turbo",
]
GEMINI_MODELS = [
    "gemini-1.5-flash",
    "gemini-2.0-flash",  # available via Gemini API / Vertex; keep synced with docs
]
DEEPSEEK_MODELS = [
    "deepseek-chat",
]
# Maps the provider label shown in the UI to its list of model identifiers.
# The first entry of each list becomes the selected model when the provider changes.
PROVIDERS = {
    "OpenAI": OPENAI_MODELS,
    "Gemini": GEMINI_MODELS,
    "DeepSeek": DEEPSEEK_MODELS,
}
| # -------- Streaming runners per provider -------- | |
def _is_openai_reasoning_model(model: str) -> bool:
    """Return True for o-series model ids such as ``o1``, ``o3-mini`` or ``o4-mini``."""
    return model[:1] == "o" and model[1:2].isdigit()


def stream_openai_like(
    model: str,
    prompt: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
    seed: Optional[int],
    api_key: str,
    base_url: Optional[str] = None,
) -> Generator[str, None, None]:
    """Stream a chat completion from OpenAI or an OpenAI-compatible endpoint.

    Args:
        model: Model identifier sent to the endpoint.
        prompt: Text sent as a single ``user`` message.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        max_tokens: Upper bound on generated tokens.
        seed: Optional sampling seed; omitted from the request when ``None``.
        api_key: API key used to build the client.
        base_url: Endpoint override (e.g. DeepSeek); the SDK default is used
            when falsy.

    Yields:
        The accumulated response text after each non-empty streamed token.
        Any exception (including a failure to build the client) is reported
        as a final ``❌ OpenAI-compatible error: ...`` message, appended to
        whatever text was already streamed.
    """
    kwargs: Dict[str, Any] = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    if not base_url and _is_openai_reasoning_model(model):
        # Per the OpenAI API reference, o-series models reject `max_tokens`
        # (they take `max_completion_tokens`) and only accept the default
        # sampling parameters, so temperature/top_p are not sent.
        kwargs["max_completion_tokens"] = max_tokens
    else:
        kwargs["temperature"] = temperature
        kwargs["top_p"] = top_p
        kwargs["max_tokens"] = max_tokens
    if seed is not None:
        kwargs["seed"] = seed

    response_text = ""
    try:
        # Built inside the try so construction errors are reported in the
        # output box like request errors instead of escaping the generator.
        client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
        for part in client.chat.completions.create(**kwargs):
            # Chunks can arrive with an empty `choices` list (for example
            # usage-only chunks); skip them instead of indexing into it.
            choices = getattr(part, "choices", None)
            if not choices:
                continue
            choice = choices[0]
            # Preferred source is `.delta.content`; `.message.content` and
            # `.text` are fallbacks for endpoints that stream other shapes.
            token = getattr(getattr(choice, "delta", None), "content", None)
            if not token:
                token = getattr(getattr(choice, "message", None), "content", None)
            if not token:
                token = getattr(choice, "text", None)
            if token:
                response_text += token
                yield response_text
    except Exception as e:
        error = f"❌ OpenAI-compatible error: {type(e).__name__}: {e}"
        # Keep any partial answer visible rather than replacing it.
        yield f"{response_text}\n\n{error}" if response_text else error
def stream_gemini(
    model: str,
    prompt: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
    seed: Optional[int],
    api_key: str,
) -> Generator[str, None, None]:
    """Stream a response from Google Gemini via the google-genai SDK.

    Uses ``client.models.generate_content_stream(...)``.

    Args:
        model: Gemini model identifier.
        prompt: Text passed as ``contents``.
        temperature: Sampling temperature (coerced to ``float``).
        top_p: Nucleus-sampling probability mass (coerced to ``float``).
        max_tokens: Sent as ``max_output_tokens`` (coerced to ``int``).
        seed: Optional sampling seed; omitted from the config when ``None``.
        api_key: API key used to build the client.

    Yields:
        The accumulated response text after each chunk that carries text.
        If the SDK is not installed, a single ``❌ Gemini SDK not installed``
        message is yielded. Any exception (including a failure to build the
        client) is reported as a final ``❌ Gemini error: ...`` message,
        appended to whatever text was already streamed.
    """
    if genai is None:
        yield "❌ Gemini SDK not installed. Add `google-genai` to requirements.txt."
        return

    cfg_kwargs: Dict[str, Any] = {
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_output_tokens": int(max_tokens),
    }
    if seed is not None:
        cfg_kwargs["seed"] = int(seed)

    response_text = ""
    try:
        # Built inside the try so construction errors are reported in the
        # output box like request errors instead of escaping the generator.
        client = genai.Client(api_key=api_key)
        stream = client.models.generate_content_stream(
            model=model,
            contents=prompt,
            config=genai_types.GenerateContentConfig(**cfg_kwargs),
        )
        for chunk in stream:
            txt = getattr(chunk, "text", None)
            if txt:
                response_text += txt
                yield response_text
    except Exception as e:
        error = f"❌ Gemini error: {type(e).__name__}: {e}"
        # Keep any partial answer visible rather than replacing it.
        yield f"{response_text}\n\n{error}" if response_text else error
| # -------- Gradio callback -------- | |
def multi_llm_complete(
    provider: str,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    seed_text: str,
    # API keys (user enters). Each also accepts the 8-digit passcode pattern.
    openai_key_input: str,
    gemini_key_input: str,
    deepseek_key_input: str,
) -> Generator[str, None, None]:
    """Gradio callback: route the request to the selected provider and stream the reply.

    Args:
        provider: ``"OpenAI"``, ``"Gemini"`` or ``"DeepSeek"``.
        model: Model identifier for that provider.
        prompt: User prompt.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        seed_text: Optional seed as text; used only when it is a string of
            digits that ``int`` can parse, otherwise no seed is sent.
        openai_key_input: Key (or passcode) typed for OpenAI.
        gemini_key_input: Key (or passcode) typed for Gemini.
        deepseek_key_input: Key (or passcode) typed for DeepSeek.

    Yields:
        Whatever the provider's streaming runner yields, or a single
        warning/error message when the provider is unknown or no key resolves.
    """
    # Resolve seed. `str.isdigit` is True for characters such as "²" that
    # `int` cannot parse, so the conversion is guarded and falls back to no seed.
    seed: Optional[int] = None
    seed_str = str(seed_text).strip() if seed_text else ""
    if seed_str.isdigit():
        try:
            seed = int(seed_str)
        except ValueError:
            seed = None

    # provider -> (typed key, passcode env var, real-key env var, base_url)
    routes = {
        "OpenAI": (openai_key_input, "OPENAI_BACKDOOR_KEY", "OPENAI_KEY", None),
        "Gemini": (gemini_key_input, "GEMINI_BACKDOOR_KEY", "GEMINI_KEY", None),
        "DeepSeek": (
            deepseek_key_input,
            "DEEPSEEK_BACKDOOR_KEY",
            "DEEPSEEK_KEY",
            "https://api.deepseek.com",  # OpenAI-compatible endpoint
        ),
    }
    route = routes.get(provider)
    if route is None:
        yield "❌ Unknown provider selection."
        return

    key_input, gate_env, payload_env, base_url = route
    api_key = _resolve_key(key_input, gate_env, payload_env) or ""
    if not api_key:
        yield f"⚠️ Enter a valid {provider} API key."
        return

    request = dict(
        model=model,
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=seed,
        api_key=api_key,
    )
    if provider == "Gemini":
        yield from stream_gemini(**request)
    else:
        # OpenAI and DeepSeek share the OpenAI-compatible runner.
        yield from stream_openai_like(**request, base_url=base_url)
| # -------- UI -------- | |
# Builds the Gradio interface: provider/model selection and generation
# parameters in the first column, per-provider API keys in the second, and a
# button that streams multi_llm_complete's output into the response box.
# NOTE(review): the button, response box and click wiring are placed directly
# under Blocks (below the two-column row) — confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🔀 Multi-LLM Chat (OpenAI • Gemini • DeepSeek)")
    gr.Markdown(
        "Pick a provider & model, enter the provider’s API key, tune params, and stream the reply. "
        "Seed (if supported) improves reproducibility."
    )
    with gr.Row():
        with gr.Column(scale=1):
            provider = gr.Dropdown(
                choices=list(PROVIDERS.keys()),
                value="OpenAI",
                label="Provider",
            )
            model = gr.Dropdown(
                choices=PROVIDERS["OpenAI"],
                value="gpt-4o",
                label="Model",
            )
            def _update_models(p):
                """Swap the model choices to those of provider *p* and select its first model."""
                return gr.update(choices=PROVIDERS[p], value=PROVIDERS[p][0])
            # Keep the model dropdown in sync with the selected provider.
            provider.change(_update_models, inputs=provider, outputs=model)
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Ask anything…",
                lines=6,
            )
            max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max tokens")
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Top-p")
            # Free-text so the seed can be left blank; parsed in multi_llm_complete.
            seed = gr.Textbox(label="🎲 Seed (optional integer)", placeholder="e.g., 42")
        with gr.Column(scale=1):
            gr.Markdown("### API Keys (per provider)")
            openai_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-... or 8-digit passcode")
            gemini_key = gr.Textbox(label="Gemini API Key", type="password", placeholder="AI Studio key or 8-digit passcode")
            deepseek_key = gr.Textbox(label="DeepSeek API Key", type="password", placeholder="ds-... or 8-digit passcode")
    submit = gr.Button("▶️ Generate", variant="primary")
    output = gr.Textbox(label="Response", lines=18)
    # The order of `inputs` must match the positional parameters of
    # multi_llm_complete.
    submit.click(
        fn=multi_llm_complete,
        inputs=[
            provider, model, prompt,
            max_tokens, temperature, top_p, seed,
            openai_key, gemini_key, deepseek_key
        ],
        outputs=output,
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    # On Spaces, consider server_name='0.0.0.0'
    demo.launch()