Spaces:
Running
Running
| import json | |
| import uuid | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple, Iterator, List, Generator, Any | |
| import struct | |
| import requests | |
class StreamTracker:
    """Pass-through wrapper that records whether a text stream yielded anything non-empty."""

    def __init__(self) -> None:
        # Becomes True once track() has passed through at least one truthy item.
        self.has_content = False

    def track(self, gen: Generator[str, None, None]) -> Generator[str, None, None]:
        """Re-yield every item of *gen* unchanged (lazily), updating ``has_content``."""
        for piece in gen:
            # The flag is set before the item is handed to the consumer.
            self.has_content = self.has_content or bool(piece)
            yield piece
def _get_proxies() -> Optional[Dict[str, str]]:
    """Return a requests ``proxies`` mapping built from ``HTTP_PROXY``.

    The (whitespace-stripped) value of the environment variable is used for
    both the ``http`` and ``https`` schemes. Returns None when the variable is
    unset or blank.
    """
    proxy_url = os.getenv("HTTP_PROXY", "").strip()
    return {scheme: proxy_url for scheme in ("http", "https")} if proxy_url else None
# Absolute directory of this module; the captured request template is looked
# up relative to it, under templates/.
BASE_DIR = Path(__file__).resolve().parent
TEMPLATE_PATH = BASE_DIR.joinpath("templates", "streaming_request.json")
def load_template(path: Optional[Path] = None) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    """Load the captured streaming-request template.

    The template file must be a JSON array ``[url, headers, body]`` where
    ``url`` is a string and ``headers`` / ``body`` are JSON objects. The file
    is re-read and re-parsed on every call, so callers get fresh objects they
    may mutate freely.

    Args:
        path: File to read (``str`` or ``Path``). Defaults to ``TEMPLATE_PATH``.

    Returns:
        ``(url, headers, body)`` as parsed from the file.

    Raises:
        OSError: the file cannot be read.
        json.JSONDecodeError: the file is not valid JSON.
        ValueError: the JSON does not have the ``[str, object, object]`` shape.
    """
    source = TEMPLATE_PATH if path is None else Path(path)
    data = json.loads(source.read_text(encoding="utf-8"))
    # Explicit checks instead of `assert`: asserts are stripped under `python -O`.
    if not isinstance(data, list) or len(data) != 3:
        raise ValueError(f"{source}: expected a JSON array [url, headers, body]")
    url, headers, body = data
    if not (isinstance(url, str) and isinstance(headers, dict) and isinstance(body, dict)):
        raise ValueError(
            f"{source}: expected [string, object, object], got "
            f"[{type(url).__name__}, {type(headers).__name__}, {type(body).__name__}]"
        )
    return url, headers, body
def _merge_headers(as_log: Dict[str, str], bearer_token: str) -> Dict[str, str]:
    """Derive outgoing request headers from a captured header set.

    Returns a new dict (``as_log`` itself is not modified) in which:

    * Content-Length, Host, Connection and Transfer-Encoding are dropped,
      matched case-insensitively (presumably because the HTTP client should
      compute these for the new request);
    * any existing Authorization / amz-sdk-invocation-id header, in any
      letter case, is removed and replaced by ``Authorization: Bearer <token>``
      and a freshly generated UUID4 ``amz-sdk-invocation-id``. Both are
      appended after the surviving headers, in that order.
    """
    forced = {
        "Authorization": f"Bearer {bearer_token}",
        "amz-sdk-invocation-id": str(uuid.uuid4()),
    }
    excluded = {"content-length", "host", "connection", "transfer-encoding"}
    excluded.update(name.lower() for name in forced)
    merged = {key: value for key, value in as_log.items() if key.lower() not in excluded}
    merged.update(forced)
    return merged
def _parse_event_headers(raw: bytes) -> Dict[str, object]:
    """Decode the headers section of one AWS event-stream message.

    Each header is encoded as::

        name_len:u8 | name:utf-8[name_len] | type:u8 | value

    with the value layout selected by ``type`` (all integers big-endian):

    ====  ===========  ==========================================
    type  meaning      decoded as
    ====  ===========  ==========================================
    0/1   bool         True / False (no value bytes)
    2     byte         signed 8-bit int
    3     short        signed 16-bit int
    4     integer      signed 32-bit int
    5     long         signed 64-bit int
    6     byte array   ``bytes`` (u16 length prefix)
    7     string       ``str`` (u16 length prefix, UTF-8)
    8     timestamp    signed 64-bit int (kept as the raw number)
    9     uuid         canonical UUID string (16 raw bytes)
    ====  ===========  ==========================================

    Parsing is best effort: on a truncated header or an unknown type code it
    stops and returns the headers decoded so far. Undecodable UTF-8 bytes in
    names and string values are dropped.
    """
    # type code -> (width in bytes, signed) for the fixed-width integer types.
    fixed_ints = {2: (1, True), 3: (2, True), 4: (4, True), 5: (8, True), 8: (8, True)}
    headers: Dict[str, object] = {}
    i = 0
    n = len(raw)
    while i < n:
        name_len = raw[i]
        i += 1
        # Need the whole name plus the one-byte type code.
        if i + name_len + 1 > n:
            break
        name = raw[i:i + name_len].decode("utf-8", errors="ignore")
        i += name_len
        htype = raw[i]
        i += 1

        val: object
        if htype == 0:
            val = True
        elif htype == 1:
            val = False
        elif htype in fixed_ints:
            width, signed = fixed_ints[htype]
            if i + width > n:
                break
            val = int.from_bytes(raw[i:i + width], "big", signed=signed)
            i += width
        elif htype in (6, 7):
            if i + 2 > n:
                break
            length = int.from_bytes(raw[i:i + 2], "big")
            i += 2
            if i + length > n:
                break
            data = raw[i:i + length]
            i += length
            val = data.decode("utf-8", errors="ignore") if htype == 7 else data
        elif htype == 9:
            if i + 16 > n:
                break
            val = str(uuid.UUID(bytes=bytes(raw[i:i + 16])))
            i += 16
        else:
            # Unknown type: its width is unknown, so nothing after it can be parsed.
            break
        headers[name] = val
    return headers
class AwsEventStreamParser:
    """Incremental decoder for AWS event-stream binary framing.

    Bytes may be fed in arbitrary chunks; every complete message is returned
    as ``(headers, payload)``. Layout of one message (big-endian)::

        total_len:u32 | headers_len:u32 | prelude_crc:u32 | headers | payload | message_crc:u32

    The prelude CRC (CRC-32 of the first 8 bytes) is checked so that a corrupt
    or misaligned length field is not trusted; on a bad prelude the parser
    drops one byte and tries to resynchronise. The trailing message CRC is not
    verified.
    """

    _PRELUDE_LEN = 12       # total_len + headers_len + prelude_crc
    _MIN_MESSAGE_LEN = 16   # prelude + trailing message CRC, no headers/payload

    def __init__(self):
        # Bytes received but not yet consumed as part of a complete message.
        self._buf = bytearray()

    def feed(self, data: bytes) -> List[Tuple[Dict[str, object], bytes]]:
        """Append *data* and return all messages that are now complete.

        Returns an empty list when *data* is empty or no message has finished
        yet; incomplete trailing bytes stay buffered for the next call.
        """
        import zlib  # local import keeps the module-level import list unchanged

        if not data:
            return []
        self._buf.extend(data)
        out: List[Tuple[Dict[str, object], bytes]] = []
        while len(self._buf) >= self._PRELUDE_LEN:
            total_len, headers_len, prelude_crc = struct.unpack_from(">III", self._buf, 0)
            framing_ok = (
                total_len >= self._MIN_MESSAGE_LEN
                # headers must fit between the prelude and the trailing CRC
                and headers_len <= total_len - self._MIN_MESSAGE_LEN
                and (zlib.crc32(self._buf[:8]) & 0xFFFFFFFF) == prelude_crc
            )
            if not framing_ok:
                # Not a valid message boundary: skip one byte and try again.
                del self._buf[0]
                continue
            if len(self._buf) < total_len:
                break  # wait for the rest of this message
            msg = bytes(self._buf[:total_len])
            del self._buf[:total_len]
            start = self._PRELUDE_LEN
            headers_raw = msg[start:start + headers_len]
            payload = msg[start + headers_len:total_len - 4]
            out.append((_parse_event_headers(headers_raw), payload))
        return out
def _try_decode_event_payload(payload: bytes) -> Optional[dict]:
    """Parse an event payload as a UTF-8 encoded JSON object.

    Returns the decoded dict, or None when the payload is not valid UTF-8, is
    not valid JSON, or decodes to something other than a JSON object (array,
    string, number, ...). Non-object JSON is rejected because the decoded value
    is consumed as a mapping (``_extract_text_from_event`` calls ``.get`` on it),
    which would otherwise raise mid-stream.
    """
    try:
        decoded = json.loads(payload.decode("utf-8"))
    except (ValueError, RecursionError):
        # UnicodeDecodeError and json.JSONDecodeError are both ValueError
        # subclasses; RecursionError covers pathologically nested JSON.
        return None
    return decoded if isinstance(decoded, dict) else None
def _extract_text_from_event(ev: dict) -> Optional[str]:
    """Pull the text fragment out of a decoded event payload.

    Lookup order, first hit wins:

    1. a non-empty ``content`` string inside the first of the wrapper objects
       assistantResponseEvent / assistantMessage / message / delta / data
       (only wrappers that are dicts are considered);
    2. a non-empty top-level ``content`` string;
    3. a ``chunks`` or ``content`` list: string items plus the ``content``
       (else ``text``) string of dict items are concatenated. The join is
       returned as soon as at least one piece was collected, even if the
       result is the empty string;
    4. a non-empty top-level ``text`` / ``delta`` / ``payload`` string.

    Returns None when nothing matches.
    """
    def has_text(value: object) -> bool:
        return isinstance(value, str) and value != ""

    wrapper_keys = ("assistantResponseEvent", "assistantMessage", "message", "delta", "data")
    for wrapper in (ev[k] for k in wrapper_keys if isinstance(ev.get(k), dict)):
        nested = wrapper.get("content")
        if has_text(nested):
            return nested

    if has_text(ev.get("content")):
        return ev["content"]

    for list_key in ("chunks", "content"):
        items = ev.get(list_key)
        if not isinstance(items, list):
            continue
        pieces: List[str] = []
        for item in items:
            if isinstance(item, str):
                pieces.append(item)
            elif isinstance(item, dict):
                for field in ("content", "text"):
                    if isinstance(item.get(field), str):
                        pieces.append(item[field])
                        break
        if pieces:
            return "".join(pieces)

    return next((ev[k] for k in ("text", "delta", "payload") if has_text(ev.get(k))), None)
def openai_messages_to_text(messages: List[Dict[str, Any]]) -> str:
    """Flatten OpenAI-style chat messages into one plain-text transcript.

    Each message becomes ``"<role>:\\n<content>"`` (role defaults to ``"user"``)
    and messages are separated by a blank line. Content handling:

    * list content: string parts and the ``text`` of dict parts are joined
      with newlines; other parts (e.g. images) are skipped;
    * missing or ``None`` content: rendered as an empty string. OpenAI-style
      payloads may carry ``content: null``, e.g. on tool-call-only assistant
      turns;
    * any other non-string value: converted with ``str()``.
    """
    rendered: List[str] = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content")
        if content is None:
            # Previously rendered as the literal text "None".
            content = ""
        elif isinstance(content, list):
            parts = [
                seg["text"] if isinstance(seg, dict) else seg
                for seg in content
                if (isinstance(seg, dict) and isinstance(seg.get("text"), str)) or isinstance(seg, str)
            ]
            content = "\n".join(parts)
        elif not isinstance(content, str):
            content = str(content)
        rendered.append(f"{role}:\n{content}")
    return "\n\n".join(rendered)
def inject_history(body_json: Dict[str, Any], history_text: str, *, placeholder: str = "你好,你必须讲个故事") -> None:
    """Splice the conversation transcript into the template's user message, in place.

    Every occurrence of *placeholder* in
    ``body_json["conversationState"]["currentMessage"]["userInputMessage"]["content"]``
    is replaced with *history_text*. The default placeholder is the Chinese
    sentence "Hello, you must tell a story", presumably the prompt used when the
    template was captured (confirm against the template file).

    Best effort: if the template lacks that nesting, or ``content`` is not a
    string, the body is left unchanged. Only template-shape errors are
    suppressed; other failures propagate instead of being silently ignored.
    """
    try:
        user_input = body_json["conversationState"]["currentMessage"]["userInputMessage"]
        content = user_input.get("content", "")
    except (KeyError, TypeError, AttributeError):
        # Missing key, or an intermediate value that is not a dict.
        return
    if isinstance(content, str):
        user_input["content"] = content.replace(placeholder, history_text)
def inject_model(body_json: Dict[str, Any], model: Optional[str]) -> None:
    """Set ``conversationState.currentMessage.userInputMessage.modelId`` in place.

    Does nothing when *model* is falsy (None or ""). Best effort: if the body
    lacks that nesting, it is left unchanged. Only template-shape errors
    (missing key, non-dict intermediate) are suppressed.
    """
    if not model:
        return
    try:
        user_input = body_json["conversationState"]["currentMessage"]["userInputMessage"]
        user_input["modelId"] = model
    except (KeyError, TypeError):
        pass
def _build_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str]) -> Tuple[str, Dict[str, str], str]:
    """Fill the request template for one chat call; returns (url, headers, JSON body text)."""
    url, headers_from_log, body_json = load_template()
    try:
        # Give every request its own conversation id.
        body_json["conversationState"]["conversationId"] = str(uuid.uuid4())
    except (KeyError, TypeError):
        pass  # template without conversationState: send it as captured
    inject_history(body_json, openai_messages_to_text(messages))
    inject_model(body_json, model)
    # _merge_headers sets the bearer token and a fresh amz-sdk-invocation-id.
    headers = _merge_headers(headers_from_log, access_token)
    return url, headers, json.dumps(body_json, ensure_ascii=False)


def _iter_response_text(resp: "requests.Response", session: "requests.Session") -> Generator[str, None, None]:
    """Yield non-empty text fragments from an event-stream response, then close it.

    JSON-object payloads contribute whatever ``_extract_text_from_event`` finds
    (events without text are skipped); other payloads are yielded as UTF-8 text
    with undecodable bytes dropped. The response and its session are closed when
    the generator is exhausted, raises, or is closed after having started.
    """
    parser = AwsEventStreamParser()
    try:
        for chunk in resp.iter_content(chunk_size=None):
            if not chunk:
                continue
            for _ev_headers, payload in parser.feed(chunk):
                parsed = _try_decode_event_payload(payload)
                if parsed is None:
                    text = payload.decode("utf-8", errors="ignore")
                else:
                    text = _extract_text_from_event(parsed)
                if isinstance(text, str) and text:
                    yield text
    finally:
        resp.close()
        session.close()


def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str] = None, stream: bool = False, timeout: Tuple[int, int] = (15, 300)) -> Tuple[Optional[str], Optional[Generator[str, None, None]], "StreamTracker"]:
    """POST a chat request built from the template and decode the streamed reply.

    Args:
        access_token: Bearer token placed in the Authorization header.
        messages: OpenAI-style chat messages, flattened into the template prompt.
        model: Optional model id written into the request body.
        stream: Return a lazy text generator instead of the full text.
        timeout: ``(connect, read)`` timeouts in seconds passed to requests.

    Returns:
        ``(text, None, tracker)`` when *stream* is False, or
        ``(None, generator, tracker)`` when it is True. ``tracker.has_content``
        reports whether any non-empty fragment has been produced so far. In
        streaming mode the HTTP connection is released once the generator is
        exhausted or closed after starting; a generator that is never iterated
        keeps the connection until it is garbage-collected.

    Raises:
        requests.HTTPError: upstream answered with status >= 400 (the response
            body, when readable, is included in the message).
        requests.RequestException: transport errors from the POST.
        Errors from ``load_template`` when the template cannot be loaded.
    """
    url, headers, payload_str = _build_chat_request(access_token, messages, model)
    session = requests.Session()
    try:
        resp = session.post(url, headers=headers, data=payload_str, stream=True, timeout=timeout, proxies=_get_proxies())
    except Exception:
        session.close()
        raise
    if resp.status_code >= 400:
        try:
            err = resp.text
        except Exception:
            # Reading the error body is best effort; fall back to the bare status.
            err = f"HTTP {resp.status_code}"
        finally:
            resp.close()
            session.close()
        raise requests.HTTPError(f"Upstream error {resp.status_code}: {err}", response=resp)

    tracker = StreamTracker()
    text_stream = tracker.track(_iter_response_text(resp, session))
    if stream:
        return None, text_stream, tracker
    return "".join(text_stream), None, tracker