# (Hugging Face Spaces page header removed — scrape artifact: "Spaces: Running on CPU Upgrade")
| from __future__ import annotations | |
| import base64 | |
| import json | |
| import os | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional, Union | |
| import requests | |
| import io | |
| import re | |
| import random | |
| from dataclasses import dataclass | |
| from typing import List, Dict, Callable, Optional, Tuple, Union | |
| import streamlit as st | |
| from PIL import Image, ImageDraw | |
| import pandas as pd | |
| from io import BytesIO | |
class BaseAdapterError(RuntimeError):
    """Raised for adapter failures: missing API keys, unsupported providers, and non-2xx HTTP responses."""
    pass
@dataclass
class BaseAdapter:
    """Shared base for provider-specific LLM HTTP adapters.

    Holds connection/config fields, resolves a missing API key from the
    environment, and provides common helpers: model listing, HTTP error
    translation, and image normalization. Subclasses implement ``generate``.

    BUG FIX: the class previously used ``field(default_factory=dict)`` and
    ``__post_init__`` without the ``@dataclass`` decorator, so ``extra_headers``
    was a raw ``Field`` object, ``__post_init__`` never ran, and subclass
    ``super().__init__(provider, model)`` calls hit ``object.__init__`` and
    raised TypeError. The helper methods also lacked ``@staticmethod``, so
    ``self._normalize_image(img)`` bound ``self`` as the image argument.
    """

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 60.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    # Canonical provider identifiers (unannotated => class constants, not fields).
    OPENAI = 'openai'
    ANTHROPIC = 'anthropic'
    GEMINI = 'gemini'
    MISTRAL = 'mistral'
    GROK = 'grok'
    COHERE = 'cohere'
    TOGETHER = 'together'
    providers = [OPENAI, ANTHROPIC, GEMINI, MISTRAL, GROK, COHERE, TOGETHER]

    def __post_init__(self) -> None:
        """Normalize the provider name and fall back to the provider's env var for the key.

        Raises:
            BaseAdapterError: if no key is available (Gemini is exempt because
                its key is also resolved per-request in ``generate``).
        """
        self.provider = self.provider.lower().strip()
        if self.api_key is None:
            env_keys = {
                "openai": "OPENAI_API_KEY",
                "anthropic": "ANTHROPIC_API_KEY",
                "gemini": "GEMINI_API_KEY",
                "mistral": "MISTRAL_API_KEY",
                "grok": "XAI_API_KEY",
                "cohere": "COHERE_API_KEY",
                "together": "TOGETHER_API_KEY",
            }
            env_var = env_keys.get(self.provider)
            if env_var:
                self.api_key = os.getenv(env_var)
        if not self.api_key and self.provider not in ("gemini",):
            raise BaseAdapterError(f"Missing api_key for {self.provider}. Set via environment.")

    @staticmethod
    def list_models(provider: str, api_key: Optional[str] = None, base_url: Optional[str] = None, timeout: float = 60.0) -> List[str]:
        """Return model identifiers advertised by *provider*.

        Anthropic has no public listing endpoint here, so a static list is
        returned; Gemini authenticates via a query parameter; the remaining
        providers all expose a bearer-authenticated ``/v1/models`` endpoint,
        handled table-driven below.
        """
        p = provider.lower().strip()
        if p == "anthropic":
            return [
                "claude-3-5-sonnet-latest",
                "claude-3-5-haiku-latest",
                "claude-3-opus-latest",
            ]
        if p == "gemini":
            key = api_key or os.getenv("GEMINI_API_KEY")
            if not key:
                raise BaseAdapterError("Missing GEMINI_API_KEY.")
            url = (base_url or "https://generativelanguage.googleapis.com") + f"/v1beta/models?key={key}"
            r = requests.get(url, timeout=timeout)
            BaseAdapter._raise_for_status_static(r)
            return [m["name"] for m in r.json().get("models", [])]
        # (default_base_url, env_var, response list key, per-model id key)
        endpoints = {
            "openai": ("https://api.openai.com", "OPENAI_API_KEY", "data", "id"),
            "grok": ("https://api.x.ai", "XAI_API_KEY", "data", "id"),
            "mistral": ("https://api.mistral.ai", "MISTRAL_API_KEY", "data", "id"),
            "cohere": ("https://api.cohere.ai", "COHERE_API_KEY", "models", "name"),
            "together": ("https://api.together.xyz", "TOGETHER_API_KEY", "data", "id"),
        }
        if p not in endpoints:
            raise BaseAdapterError(f"Unsupported provider: {p}")
        default_base, env_var, list_key, id_key = endpoints[p]
        url = (base_url or default_base) + "/v1/models"
        headers = {"Authorization": f"Bearer {api_key or os.getenv(env_var)}"}
        r = requests.get(url, headers=headers, timeout=timeout)
        BaseAdapter._raise_for_status_static(r)
        return [m[id_key] for m in r.json().get(list_key, [])]

    # ---------- Utilities ---------- #
    @staticmethod
    def _raise_for_status_static(response: requests.Response) -> None:
        """Raise BaseAdapterError for any non-2xx response, embedding the body for debugging."""
        if 200 <= response.status_code < 300:
            return
        try:
            detail = response.json()
            msg = json.dumps(detail)
        except Exception:
            # Body was not JSON; fall back to the raw text.
            msg = response.text
        raise BaseAdapterError(f"HTTP {response.status_code}: {msg}")

    def _raise_for_status(self, response: requests.Response) -> None:
        """Instance-friendly wrapper around :meth:`_raise_for_status_static`."""
        return self._raise_for_status_static(response)

    @staticmethod
    def _detect_mime(b: bytes) -> str:
        """Sniff an image MIME type from magic bytes (PNG/JPEG/GIF/WEBP), else octet-stream."""
        if len(b) >= 8 and b[:8] == b"\x89PNG\r\n\x1a\n":
            return "image/png"
        if len(b) >= 3 and b[:3] == b"\xff\xd8\xff":
            return "image/jpeg"
        if len(b) >= 6 and b[:6] in (b"GIF87a", b"GIF89a"):
            return "image/gif"
        if len(b) >= 12 and b[8:12] == b"WEBP":
            return "image/webp"
        return "application/octet-stream"

    @staticmethod
    def _normalize_image(image: Union[str, bytes], default_mime: str = "image/png") -> tuple[str, str, str]:
        """Return (data_url, base64_str, mime_type) for the given image input.

        Accepts bytes, base64 string, data URL, or local file path.

        Raises:
            BaseAdapterError: for any other input type.
        """
        if isinstance(image, bytes):
            b64 = base64.b64encode(image).decode()
            mime = BaseAdapter._detect_mime(image)
            if mime == "application/octet-stream":
                mime = default_mime
            return f"data:{mime};base64,{b64}", b64, mime
        if isinstance(image, str):
            if image.startswith("data:"):
                # data:image/png;base64,XXXX
                header, b64 = image.split(",", 1)
                mime = header.split(";")[0].split(":", 1)[1] or default_mime
                return image, b64, mime
            if os.path.exists(image):
                with open(image, "rb") as f:
                    raw = f.read()
                b64 = base64.b64encode(raw).decode()
                mime = BaseAdapter._detect_mime(raw)
                if mime == "application/octet-stream":
                    mime = default_mime
                return f"data:{mime};base64,{b64}", b64, mime
            # assume bare base64 string
            b64 = image
            mime = default_mime
            return f"data:{mime};base64,{b64}", b64, mime
        raise BaseAdapterError("Unsupported image type; pass bytes, path, base64 string, or data URL.")
class OpenaiAdapter(BaseAdapter):
    """Adapter for the OpenAI chat-completions HTTP API."""

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 60.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def __init__(self, model_name):
        super().__init__(BaseAdapter.OPENAI, model_name)

    def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]]] = None, **kwargs: Any) -> str:
        """POST one chat turn (text plus optional images) and return the assistant text."""
        endpoint = (self.base_url or "https://api.openai.com") + "/v1/chat/completions"
        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            **self.extra_headers,
        }
        user_content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
        if image is not None:
            images = image if isinstance(image, list) else [image]
            for img in images:
                img_url, _b64, _mime = self._normalize_image(img, default_mime="image/png")
                user_content.append({"type": "image_url", "image_url": {"url": img_url}})
        messages: List[Dict[str, Any]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": user_content})
        body = {"model": self.model, "messages": messages}
        resp = requests.post(endpoint, headers=request_headers, json=body, timeout=self.timeout)
        self._raise_for_status(resp)
        return resp.json()["choices"][0]["message"]["content"].strip()
class AnthropicAdapter(BaseAdapter):
    """Adapter for the Anthropic Messages HTTP API."""

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 60.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def __init__(self, model_name):
        super().__init__(BaseAdapter.ANTHROPIC, model_name)

    def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]]] = None, **kwargs: Any) -> str:
        """Send a single user message (text plus optional base64 images) and return the reply text."""
        endpoint = (self.base_url or "https://api.anthropic.com") + "/v1/messages"
        request_headers = {
            "x-api-key": self.api_key or "",
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
            **self.extra_headers,
        }
        blocks: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
        if image is not None:
            for img in (image if isinstance(image, list) else [image]):
                _url, encoded, media_type = self._normalize_image(img, default_mime="image/png")
                blocks.append({
                    "type": "image",
                    "source": {"type": "base64", "media_type": media_type, "data": encoded},
                })
        body: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": kwargs.get("max_tokens", 1024),
            "messages": [{"role": "user", "content": blocks}],
        }
        if system:
            # Anthropic takes the system prompt as a top-level field, not a message.
            body["system"] = system
        resp = requests.post(endpoint, headers=request_headers, json=body, timeout=self.timeout)
        self._raise_for_status(resp)
        text_blocks = resp.json().get("content", [])
        return "".join(blk.get("text", "") for blk in text_blocks if blk.get("type") == "text").strip()
class GeminiAdapter(BaseAdapter):
    """Adapter for the Google Gemini ``generateContent`` HTTP API."""

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 300.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def __init__(self, model_name):
        super().__init__(BaseAdapter.GEMINI, model_name)

    def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]]] = None, **kwargs: Any) -> str:
        """POST one generateContent request and return the first candidate's text.

        Raises:
            BaseAdapterError: if no API key is configured or the request fails.
        """
        key = self.api_key or os.getenv("GEMINI_API_KEY")
        if not key:
            raise BaseAdapterError("Missing GEMINI_API_KEY.")
        base = self.base_url or "https://generativelanguage.googleapis.com"
        url = f"{base}/v1beta/models/{self.model}:generateContent?key={key}"
        headers = {"Content-Type": "application/json", **self.extra_headers}
        parts: List[Dict[str, Any]] = [{"text": prompt}]
        if image is not None:
            if not isinstance(image, list):
                image = [image]
            for img in image:
                _data_url, b64, mime = self._normalize_image(img, default_mime="image/png")
                parts.append({"inline_data": {"mime_type": mime, "data": b64}})
        payload: Dict[str, Any] = {"contents": [{"role": "user", "parts": parts}]}
        if system:
            # BUG FIX: the generateContent API rejects a "system" role inside
            # `contents`; system prompts belong in the top-level systemInstruction.
            payload["systemInstruction"] = {"parts": [{"text": system}]}
        if 'gemini-3-pro' in self.model:
            payload["generationConfig"] = {
                "thinkingConfig": {
                    "thinkingLevel": "low"
                }
            }
        r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
        self._raise_for_status(r)
        data = r.json()
        return data["candidates"][0]["content"]["parts"][0]["text"].strip()
class MistralAdapter(BaseAdapter):
    """Adapter for the Mistral chat-completions HTTP API."""

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 60.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def __init__(self, model_name):
        super().__init__(BaseAdapter.MISTRAL, model_name)

    def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]]] = None, **kwargs: Any) -> str:
        """POST one chat turn (text plus optional images) and return the reply text."""
        endpoint = (self.base_url or "https://api.mistral.ai") + "/v1/chat/completions"
        hdrs = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
        parts: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
        images = [] if image is None else (image if isinstance(image, list) else [image])
        for img in images:
            link, _b64, _mime = self._normalize_image(img, default_mime="image/png")
            parts.append({"type": "image_url", "image_url": {"url": link}})
        convo: List[Dict[str, Any]] = [{"role": "system", "content": system}] if system else []
        convo.append({"role": "user", "content": parts})
        reply = requests.post(endpoint, headers=hdrs, json={"model": self.model, "messages": convo}, timeout=self.timeout)
        self._raise_for_status(reply)
        return reply.json()["choices"][0]["message"]["content"].strip()
class GrokAdapter(BaseAdapter):
    """Adapter for the xAI Grok chat-completions HTTP API."""

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 60.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def __init__(self, model_name):
        super().__init__(BaseAdapter.GROK, model_name)

    def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]]] = None, **kwargs: Any) -> str:
        """Run one chat turn against api.x.ai and return the assistant's text."""
        endpoint = (self.base_url or "https://api.x.ai") + "/v1/chat/completions"
        # extra_headers is applied last so callers can override the defaults.
        hdrs: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        hdrs.update(self.extra_headers)
        user_parts: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
        if image is not None:
            for img in (image if isinstance(image, list) else [image]):
                link = self._normalize_image(img, default_mime="image/png")[0]
                user_parts.append({"type": "image_url", "image_url": {"url": link}})
        convo: List[Dict[str, Any]] = []
        if system:
            convo.append({"role": "system", "content": system})
        convo.append({"role": "user", "content": user_parts})
        resp = requests.post(endpoint, headers=hdrs, json={"model": self.model, "messages": convo}, timeout=self.timeout)
        self._raise_for_status(resp)
        return resp.json()["choices"][0]["message"]["content"].strip()
class TogetherAdapter(BaseAdapter):
    """Adapter for the Together AI chat-completions HTTP API."""

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 60.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def __init__(self, model_name):
        super().__init__(BaseAdapter.TOGETHER, model_name)

    def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]]] = None, **kwargs: Any) -> str:
        """POST one chat turn (text plus optional images) and return the reply text."""
        endpoint = (self.base_url or "https://api.together.xyz") + "/v1/chat/completions"
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
        segments: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
        if image is not None:
            image_list = image if isinstance(image, list) else [image]
            segments.extend(
                {"type": "image_url", "image_url": {"url": self._normalize_image(img, default_mime="image/png")[0]}}
                for img in image_list
            )
        chat: List[Dict[str, Any]] = []
        if system:
            chat.append({"role": "system", "content": system})
        chat.append({"role": "user", "content": segments})
        response = requests.post(endpoint, headers=headers, json={"model": self.model, "messages": chat}, timeout=self.timeout)
        self._raise_for_status(response)
        return response.json()["choices"][0]["message"]["content"].strip()
class CohereAdapter(BaseAdapter):
    """Adapter for the Cohere ``/v1/chat`` HTTP API."""

    provider: str
    model: str
    api_key: Optional[str] = None
    timeout: float = 60.0
    base_url: Optional[str] = None
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def __init__(self, model_name):
        super().__init__(BaseAdapter.COHERE, model_name)

    def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]]] = None, **kwargs: Any) -> str:
        """POST one chat message (with optional image attachments) and return the reply text."""
        url = (self.base_url or "https://api.cohere.ai") + "/v1/chat"
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
        payload: Dict[str, Any] = {"model": self.model, "message": prompt}
        if system:
            # Cohere takes the system prompt as the chat "preamble".
            payload["preamble"] = system
        if image is not None:
            # BUG FIX: was `isinstane` (NameError at runtime).
            if not isinstance(image, list):
                image = [image]
            attachments: List[Dict[str, Any]] = []
            for img in image:
                data_url, _b64, mime = self._normalize_image(img, default_mime="image/png")
                # Cohere chat supports attachments; we send a data URL to keep it dependency-light
                attachments.append({
                    "type": "image",
                    "image_url": data_url,
                    "mime_type": mime,
                })
            # BUG FIX: payload["attachments"] was reassigned to a one-element
            # list inside the loop, silently dropping all but the last image.
            payload["attachments"] = attachments
        r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
        self._raise_for_status(r)
        data = r.json()
        # Cohere responses can be under 'text' or 'message.content'
        return (data.get("text") or data.get("message", {}).get("content", [{}])[0].get("text", "")).strip()
| ''' | |
| @dataclass | |
| class UniLLM: | |
| provider: str | |
| model: str | |
| api_key: Optional[str] = None | |
| timeout: float = 60.0 | |
| base_url: Optional[str] = None | |
| extra_headers: Dict[str, str] = field(default_factory=dict) | |
| def __post_init__(self) -> None: | |
| self.provider = self.provider.lower().strip() | |
| if self.api_key is None: | |
| env_keys = { | |
| "openai": "OPENAI_API_KEY", | |
| "anthropic": "ANTHROPIC_API_KEY", | |
| "gemini": "GEMINI_API_KEY", | |
| "mistral": "MISTRAL_API_KEY", | |
| "cohere": "COHERE_API_KEY", | |
| "together": "TOGETHER_API_KEY", | |
| } | |
| env_var = env_keys.get(self.provider) | |
| if env_var: | |
| self.api_key = os.getenv(env_var) | |
| if not self.api_key and self.provider not in ("gemini",): | |
| raise UniLLMError(f"Missing api_key for {self.provider}. Set via environment.") | |
| # ---------- Public API ---------- # | |
| def generate(self, prompt: str, system: Optional[str] = None, image: Optional[Union[str, bytes]] = None, **kwargs: Any) -> str: | |
| p = self.provider | |
| if p == "openai": | |
| return self._openai_chat(prompt, system, image, **kwargs) | |
| if p == "anthropic": | |
| return self._anthropic_messages(prompt, system, image, **kwargs) | |
| if p == "gemini": | |
| return self._gemini_generate_content(prompt, system, image, **kwargs) | |
| if p == "mistral": | |
| return self._mistral_chat(prompt, system, image, **kwargs) | |
| if p == "cohere": | |
| return self._cohere_chat(prompt, system, image, **kwargs) | |
| if p == "together": | |
| return self._together_chat(prompt, system, image, **kwargs) | |
| raise UniLLMError(f"Unsupported provider: {p}") | |
| @staticmethod | |
| def list_models(provider: str, api_key: Optional[str] = None, base_url: Optional[str] = None, timeout: float = 60.0) -> List[str]: | |
| p = provider.lower().strip() | |
| if p == "openai": | |
| url = (base_url or "https://api.openai.com") + "/v1/models" | |
| headers = {"Authorization": f"Bearer {api_key or os.getenv('OPENAI_API_KEY')}"} | |
| r = requests.get(url, headers=headers, timeout=timeout) | |
| UniLLM._raise_for_status_static(r) | |
| return [m["id"] for m in r.json().get("data", [])] | |
| if p == "anthropic": | |
| return [ | |
| "claude-3-5-sonnet-latest", | |
| "claude-3-5-haiku-latest", | |
| "claude-3-opus-latest", | |
| ] | |
| if p == "gemini": | |
| key = api_key or os.getenv("GEMINI_API_KEY") | |
| if not key: | |
| raise UniLLMError("Missing GEMINI_API_KEY.") | |
| url = (base_url or "https://generativelanguage.googleapis.com") + f"/v1beta/models?key={key}" | |
| r = requests.get(url, timeout=timeout) | |
| UniLLM._raise_for_status_static(r) | |
| return [m["name"] for m in r.json().get("models", [])] | |
| if p == "mistral": | |
| url = (base_url or "https://api.mistral.ai") + "/v1/models" | |
| headers = {"Authorization": f"Bearer {api_key or os.getenv('MISTRAL_API_KEY')}"} | |
| r = requests.get(url, headers=headers, timeout=timeout) | |
| UniLLM._raise_for_status_static(r) | |
| return [m["id"] for m in r.json().get("data", [])] | |
| if p == "cohere": | |
| url = (base_url or "https://api.cohere.ai") + "/v1/models" | |
| headers = {"Authorization": f"Bearer {api_key or os.getenv('COHERE_API_KEY')}"} | |
| r = requests.get(url, headers=headers, timeout=timeout) | |
| UniLLM._raise_for_status_static(r) | |
| return [m["name"] for m in r.json().get("models", [])] | |
| if p == "together": | |
| url = (base_url or "https://api.together.xyz") + "/v1/models" | |
| headers = {"Authorization": f"Bearer {api_key or os.getenv('TOGETHER_API_KEY')}"} | |
| r = requests.get(url, headers=headers, timeout=timeout) | |
| UniLLM._raise_for_status_static(r) | |
| return [m["id"] for m in r.json().get("data", [])] | |
| raise UniLLMError(f"Unsupported provider: {p}") | |
| # ---------- Provider helpers ---------- # | |
| def _openai_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str: | |
| url = (self.base_url or "https://api.openai.com") + "/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers} | |
| messages = [] | |
| if system: | |
| messages.append({"role": "system", "content": system}) | |
| content = [{"type": "text", "text": prompt}] | |
| data_url = None | |
| if image is not None: | |
| data_url, _b64, _mime = self._normalize_image(image, default_mime="image/png") | |
| content.append({"type": "image_url", "image_url": {"url": data_url}}) | |
| messages.append({"role": "user", "content": content}) | |
| payload = {"model": self.model, "messages": messages} | |
| r = requests.post(url, headers=headers, json=payload, timeout=self.timeout) | |
| self._raise_for_status(r) | |
| data = r.json() | |
| return data["choices"][0]["message"]["content"].strip() | |
| def _anthropic_messages(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str: | |
| url = (self.base_url or "https://api.anthropic.com") + "/v1/messages" | |
| headers = { | |
| "x-api-key": self.api_key or "", | |
| "anthropic-version": "2023-06-01", | |
| "content-type": "application/json", | |
| **self.extra_headers, | |
| } | |
| content_items: List[Dict[str, Any]] = [{"type": "text", "text": prompt}] | |
| if image is not None: | |
| _data_url, b64, mime = self._normalize_image(image, default_mime="image/png") | |
| content_items.append({ | |
| "type": "image", | |
| "source": { | |
| "type": "base64", | |
| "media_type": mime, | |
| "data": b64, | |
| }, | |
| }) | |
| payload: Dict[str, Any] = { | |
| "model": self.model, | |
| "max_tokens": kwargs.get("max_tokens", 1024), | |
| "messages": [{"role": "user", "content": content_items}], | |
| } | |
| if system: | |
| payload["system"] = system | |
| r = requests.post(url, headers=headers, json=payload, timeout=self.timeout) | |
| self._raise_for_status(r) | |
| data = r.json() | |
| parts = data.get("content", []) | |
| return "".join(p.get("text", "") for p in parts if p.get("type") == "text").strip() | |
| def _gemini_generate_content(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str: | |
| key = self.api_key or os.getenv("GEMINI_API_KEY") | |
| if not key: | |
| raise UniLLMError("Missing GEMINI_API_KEY.") | |
| base = self.base_url or "https://generativelanguage.googleapis.com" | |
| url = f"{base}/v1beta/models/{self.model}:generateContent?key={key}" | |
| headers = {"Content-Type": "application/json", **self.extra_headers} | |
| parts = [{"text": prompt}] | |
| if image is not None: | |
| _data_url, b64, mime = self._normalize_image(image, default_mime="image/png") | |
| parts.append({"inline_data": {"mime_type": mime, "data": b64}}) | |
| contents = [{"role": "user", "parts": parts}] | |
| if system: | |
| contents.insert(0, {"role": "system", "parts": [{"text": system}]}) | |
| payload: Dict[str, Any] = {"contents": contents} | |
| r = requests.post(url, headers=headers, json=payload, timeout=self.timeout) | |
| self._raise_for_status(r) | |
| data = r.json() | |
| return data["candidates"][0]["content"]["parts"][0]["text"].strip() | |
| def _mistral_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str: | |
| url = (self.base_url or "https://api.mistral.ai") + "/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers} | |
| messages = [] | |
| if system: | |
| messages.append({"role": "system", "content": system}) | |
| content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}] | |
| if image is not None: | |
| data_url, _b64, _mime = self._normalize_image(image, default_mime="image/png") | |
| content.append({"type": "image_url", "image_url": {"url": data_url}}) | |
| messages.append({"role": "user", "content": content}) | |
| payload = {"model": self.model, "messages": messages} | |
| r = requests.post(url, headers=headers, json=payload, timeout=self.timeout) | |
| self._raise_for_status(r) | |
| data = r.json() | |
| return data["choices"][0]["message"]["content"].strip() | |
| def _cohere_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str: | |
| url = (self.base_url or "https://api.cohere.ai") + "/v1/chat" | |
| headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers} | |
| payload: Dict[str, Any] = {"model": self.model, "message": prompt} | |
| if system: | |
| payload["preamble"] = system | |
| if image is not None: | |
| data_url, _b64, mime = self._normalize_image(image, default_mime="image/png") | |
| # Cohere chat supports attachments; we send a data URL to keep it dependency-light | |
| payload["attachments"] = [ | |
| { | |
| "type": "image", | |
| "image_url": data_url, | |
| "mime_type": mime, | |
| } | |
| ] | |
| r = requests.post(url, headers=headers, json=payload, timeout=self.timeout) | |
| self._raise_for_status(r) | |
| data = r.json() | |
| # Cohere responses can be under 'text' or 'message.content' | |
| return (data.get("text") or data.get("message", {}).get("content", [{}])[0].get("text", "")).strip() | |
| def _together_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str: | |
| url = (self.base_url or "https://api.together.xyz") + "/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers} | |
| messages = [] | |
| if system: | |
| messages.append({"role": "system", "content": system}) | |
| content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}] | |
| if image is not None: | |
| data_url, _b64, _mime = self._normalize_image(image, default_mime="image/png") | |
| content.append({"type": "image_url", "image_url": {"url": data_url}}) | |
| messages.append({"role": "user", "content": content}) | |
| payload = {"model": self.model, "messages": messages} | |
| r = requests.post(url, headers=headers, json=payload, timeout=self.timeout) | |
| self._raise_for_status(r) | |
| data = r.json() | |
| return data["choices"][0]["message"]["content"].strip() | |
| # ---------- Utilities ---------- # | |
| @staticmethod | |
| def _raise_for_status_static(response: requests.Response) -> None: | |
| if 200 <= response.status_code < 300: | |
| return | |
| try: | |
| detail = response.json() | |
| msg = json.dumps(detail) | |
| except Exception: | |
| msg = response.text | |
| raise UniLLMError(f"HTTP {response.status_code}: {msg}") | |
| def _raise_for_status(self, response: requests.Response) -> None: | |
| return self._raise_for_status_static(response) | |
| @staticmethod | |
| def _detect_mime(b: bytes) -> str: | |
| if len(b) >= 8 and b[:8] == b"\x89PNG\r\n\x1a\n": | |
| return "image/png" | |
| if len(b) >= 3 and b[:3] == b"\xff\xd8\xff": | |
| return "image/jpeg" | |
| if len(b) >= 6 and b[:6] in (b"GIF87a", b"GIF89a"): | |
| return "image/gif" | |
| if len(b) >= 12 and b[8:12] == b"WEBP": | |
| return "image/webp" | |
| return "application/octet-stream" | |
| @staticmethod | |
| def _normalize_image(image: Union[str, bytes], default_mime: str = "image/png") -> tuple[str, str, str]: | |
| """Return (data_url, base64_str, mime_type) for the given image input. | |
| Accepts bytes, base64 string, data URL, or local file path. | |
| """ | |
| if isinstance(image, bytes): | |
| b64 = base64.b64encode(image).decode() | |
| mime = UniLLM._detect_mime(image) | |
| if mime == "application/octet-stream": | |
| mime = default_mime | |
| return f"data:{mime};base64,{b64}", b64, mime | |
| if isinstance(image, str): | |
| if image.startswith("data:"): | |
| header, b64 = image.split(",", 1) | |
| # data:image/png;base64,XXXX | |
| mime = header.split(";")[0].split(":", 1)[1] or default_mime | |
| return image, b64, mime | |
| if os.path.exists(image): | |
| with open(image, "rb") as f: | |
| raw = f.read() | |
| b64 = base64.b64encode(raw).decode() | |
| mime = BaseAdapter._detect_mime(raw) | |
| if mime == "application/octet-stream": | |
| mime = default_mime | |
| return f"data:{mime};base64,{b64}", b64, mime | |
| # assume bare base64 string | |
| b64 = image | |
| mime = default_mime | |
| return f"data:{mime};base64,{b64}", b64, mime | |
| raise BaseAdapter("Unsupported image type; pass bytes, path, base64 string, or data URL.") | |
| ''' |