import os

import tiktoken
from champ.rag import (
    create_embedding_model,
    create_session_vector_store,
    load_vector_store,
)
from champ.service import ChampService
from classes.base_models import ChatMessage
from constants import MODEL_MAP
from helpers.dynamodb_helper import log_environment_event
from helpers.message_helper import (
    convert_messages,
    convert_messages_langchain,
    convert_messages_qwen,
)
from helpers.impacts_tracker_helper import (
    get_openai_impacts,
    get_champ_impacts,
    get_qwen_impacts,
)
from opentelemetry import trace
from google import genai
from openai import AsyncOpenAI
from transformers import AutoTokenizer
from typing import Any, AsyncGenerator, Dict, List, Literal, Tuple

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is None:
    raise RuntimeError(
        "OPENAI_API_KEY is not set. "
        "Go to Space → Settings → Variables & secrets and add one."
    )

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY is None:
    raise RuntimeError(
        "GEMINI_API_KEY is not set. "
        "Go to Space → Settings → Variables & secrets and add one."
    )

# Both keys are validated above, so the clients can be created unconditionally.
openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
gemini_client = genai.Client(api_key=GEMINI_API_KEY)

embedding_model = create_embedding_model()
base_vector_store = load_vector_store(embedding_model)
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-9B")


def _get_vector_store(document_contents: List[str] | None):
    if document_contents is None:
        vector_store = base_vector_store
    else:
        vector_store = create_session_vector_store(
            base_vector_store, embedding_model, document_contents
        )
    return vector_store
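

# Illustrative sketch (not called by the service): without session uploads the
# shared base index is reused as-is; with uploads, a temporary per-session
# store is built on top of it. The helper name below is hypothetical.
def _example_store_selection(uploaded_texts: List[str] | None) -> bool:
    store = _get_vector_store(uploaded_texts)
    # True when the shared base index is being reused directly
    return store is base_vector_store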


async def _call_openai(
    model_id: str, msgs: list[dict], document_texts: List[str] | None = None
) -> AsyncGenerator[str, None]:
    final_reply = ""
    output_token_count = 0
    stream = await openai_client.responses.create(
        model=model_id, input=msgs, stream=True
    )
    async for chunk in stream:
        # The ecologits package does not work with the OpenAI client in
        # streaming mode. According to its documentation it should, but in our
        # experiments no output chunk carried the "impacts" attribute, so we
        # compute the impacts manually from the usage metadata below.
        if chunk.type == "response.output_text.delta":
            final_reply += chunk.delta
            yield chunk.delta
        elif chunk.type == "response.completed":
            # The final chunk contains the usage metadata. Note that
            # output_tokens includes reasoning tokens; disabling reasoning
            # would make this count match the visible reply more closely.
            output_token_count = chunk.response.usage.output_tokens
    openai_impact = get_openai_impacts(output_token_count)
    log_environment_event("inference", openai_impact, "openai")
    gwp_avg_value = (
        openai_impact.usage.gwp.value.min + openai_impact.usage.gwp.value.max  # pyright: ignore[reportAttributeAccessIssue]
    ) / 2
    yield f"\n###EMISSIONS:{gwp_avg_value}###"
    yield f"\n###TOKEN_COUNT:{output_token_count}###"


# Passing both the model id and the model type here is redundant; the call_llm
# interface could be refactored so that every model shares a unified
# interface, but that is not a priority.
def _call_gemini(
    model_id: str, msgs: list[dict], model_type: str
) -> tuple[str, float, int]:
    # Flatten the chat history into a single plain-text transcript.
    transcript = []
    for m in msgs:
        role = m["role"]
        content = m["content"]
        transcript.append(f"{role.upper()}: {content}")
    contents = "\n".join(transcript)
    temperature = 0.2 if model_type == "google-conservative" else 1.0
    resp = gemini_client.models.generate_content(
        model=model_id,
        contents=contents,
        config={"temperature": temperature},
    )
    output_token_count = (
        resp.usage_metadata.candidates_token_count
        if resp.usage_metadata is not None
        else 0
    )
    log_environment_event("inference", resp.impacts, model_type)  # pyright: ignore[reportAttributeAccessIssue]
    # Ecologits reports Gemini's impact as a min/max range; average it to get
    # a single value.
    gwp_avg_value = (
        resp.impacts.usage.gwp.value.min + resp.impacts.usage.gwp.value.max  # pyright: ignore[reportAttributeAccessIssue]
    ) / 2
    return (resp.text or "").strip(), gwp_avg_value, output_token_count or 0
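

# Illustrative sketch: the two Google model types select different sampling
# temperatures above (0.2 vs 1.0). The user message is a placeholder.
def _example_gemini_temperatures() -> tuple[str, str]:
    msgs = [{"role": "user", "content": "Summarize our triage policy."}]
    conservative, _, _ = _call_gemini(
        MODEL_MAP["google-conservative"], msgs, "google-conservative"
    )
    creative, _, _ = _call_gemini(
        MODEL_MAP["google-creative"], msgs, "google-creative"
    )
    return conservative, creative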


def _call_champ(
    lang: Literal["en", "fr"],
    conversation: List[ChatMessage],
    document_contents: List[str] | None,
    prompt_template: str | None = None,
) -> tuple[str, float, dict[str, Any], list[str], int]:
    tracer = trace.get_tracer(__name__)
    vector_store = _get_vector_store(document_contents)
    with tracer.start_as_current_span("ChampService"):
        champ = ChampService(
            vector_store=vector_store,
            lang=lang,
            model_type="champ",
            prompt_template=prompt_template,
        )
    with tracer.start_as_current_span("convert_messages_langchain"):
        msgs = convert_messages_langchain(conversation)
    with tracer.start_as_current_span("invoke"):
        reply, triage_meta, context, n_tokens = champ.invoke(msgs)
    # LangChain is not compatible with Ecologits, so we approximate the
    # environmental impact from the output token count.
    encoding = tiktoken.get_encoding("o200k_harmony")
    final_token_count = len(encoding.encode(reply))
    champ_impacts = get_champ_impacts(final_token_count)
    log_environment_event("inference", champ_impacts, "champ")
    return (
        reply,
        champ_impacts.usage.gwp.value,  # pyright: ignore[reportReturnType]
        triage_meta,
        context,
        final_token_count,
    )
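

# Illustrative sketch: the same approximation with both tokenizers available in
# this module. Counts differ between vocabularies, so the impact estimate is
# approximate by construction.
def _example_token_count_comparison(reply_text: str) -> tuple[int, int]:
    tiktoken_count = len(tiktoken.get_encoding("o200k_harmony").encode(reply_text))
    qwen_count = len(qwen_tokenizer.encode(reply_text))
    return tiktoken_count, qwen_count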


def _call_qwen(
    lang: Literal["en", "fr"],
    conversation: List[ChatMessage],
    document_contents: List[str] | None,
) -> tuple[str, float, dict[str, Any], list[str], int]:
    vector_store = _get_vector_store(document_contents)
    champ = ChampService(vector_store=vector_store, lang=lang, model_type="qwen")
    msgs = convert_messages_qwen(conversation)
    reply, triage_meta, context, n_tokens = champ.invoke(msgs)
    # Ecologits does not support Qwen yet because the model is too recent;
    # support may land in the library eventually.
    qwen_impacts = get_qwen_impacts(n_tokens)
    log_environment_event("inference", qwen_impacts, "qwen")
    return (
        reply,
        qwen_impacts.usage.gwp.value,  # pyright: ignore[reportReturnType]
        triage_meta,
        context,
        n_tokens,
    )


def call_llm(
    model_type: str,
    lang: Literal["en", "fr"],
    conversation: List[ChatMessage],
    document_contents: List[str] | None,
) -> AsyncGenerator[str, None] | Tuple[str, float, Dict[str, Any], List[str], int]:
    if model_type not in MODEL_MAP:
        raise ValueError(f"Unknown model_type: {model_type}")
    if model_type == "champ":
        return _call_champ(lang, conversation, document_contents)
    elif model_type == "qwen":
        return _call_qwen(lang, conversation, document_contents)
    model_id = MODEL_MAP[model_type]
    msgs = convert_messages(conversation, lang=lang, docs_content=document_contents)
    if model_type == "openai":
        return _call_openai(model_id, msgs)
    if model_type in ["google-conservative", "google-creative"]:
        reply, gwp_emissions, output_token_count = _call_gemini(
            model_id, msgs, model_type
        )
        return reply, gwp_emissions, {}, [], output_token_count
    # If you later add HF models via hf_client, handle here.
    raise ValueError(f"Unhandled model_type: {model_type}")