# helpers/llm_helper.py
import os
import tiktoken
from champ.rag import (
create_embedding_model,
create_session_vector_store,
load_vector_store,
)
from champ.service import ChampService
from classes.base_models import ChatMessage
from constants import MODEL_MAP
from helpers.dynamodb_helper import log_environment_event
from helpers.message_helper import (
convert_messages,
convert_messages_langchain,
convert_messages_qwen,
)
from helpers.impacts_tracker_helper import (
get_openai_impacts,
get_champ_impacts,
get_qwen_impacts,
)
from opentelemetry import trace
from google import genai
from openai import AsyncOpenAI
from transformers import AutoTokenizer
from typing import Any, AsyncGenerator, Dict, List, Literal, Tuple

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is None:
    raise RuntimeError(
        "OPENAI_API_KEY is not set. "
        "Go to Space → Settings → Variables & secrets and add one."
    )

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY is None:
    raise RuntimeError(
        "GEMINI_API_KEY is not set. "
        "Go to Space → Settings → Variables & secrets and add one."
    )

# The checks above guarantee both keys are present, so the clients can be
# constructed unconditionally.
openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
gemini_client = genai.Client(api_key=GEMINI_API_KEY)

embedding_model = create_embedding_model()
base_vector_store = load_vector_store(embedding_model)
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-9B")


def _get_vector_store(document_contents: List[str] | None):
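    """Return the shared base vector store, or a session-scoped store that
    layers the uploaded document contents on top of it."""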
    if document_contents is None:
        vector_store = base_vector_store
    else:
        vector_store = create_session_vector_store(
            base_vector_store, embedding_model, document_contents
        )
    return vector_store


async def _call_openai(
    model_id: str, msgs: list[dict]
) -> AsyncGenerator[str, None]:
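    """Stream the OpenAI reply as text deltas, then yield two sentinel lines
    ("###EMISSIONS:...###" and "###TOKEN_COUNT:...###") carrying the
    estimated emissions and the output token count."""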
    output_token_count = 0
    stream = await openai_client.responses.create(
        model=model_id, input=msgs, stream=True
    )
    async for chunk in stream:
        # Ecologits does not work with the OpenAI client in streaming mode.
        # Its documentation says it should, but in practice no output chunk
        # carried an "impacts" attribute, so the impact is estimated from
        # the token count below instead.
        if chunk.type == "response.output_text.delta":
            yield chunk.delta
        elif chunk.type == "response.completed":
            # The final chunk carries the usage metadata. Note that
            # output_tokens includes reasoning tokens; disabling reasoning
            # would shrink this count.
            output_token_count = chunk.response.usage.output_tokens
    openai_impact = get_openai_impacts(output_token_count)
    log_environment_event("inference", openai_impact, "openai")
    # Ecologits reports GWP as a min/max range; average it to a single value.
    gwp_avg_value = (
        openai_impact.usage.gwp.value.min + openai_impact.usage.gwp.value.max  # pyright: ignore[reportAttributeAccessIssue]
    ) / 2
    yield f"\n###EMISSIONS:{gwp_avg_value}###"
    yield f"\n###TOKEN_COUNT:{output_token_count}###"
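

# A minimal sketch of how a downstream consumer might split the sentinel
# markers back out of the streamed text. The helper name and the regex are
# illustrative assumptions, not part of the existing call sites.
def _split_sentinels_sketch(chunk: str) -> Tuple[str, Dict[str, float]]:
    import re  # local import keeps the sketch self-contained

    meta: Dict[str, float] = {}

    def _capture(match: "re.Match[str]") -> str:
        meta[match.group(1).lower()] = float(match.group(2))
        return ""

    # Matches the "\n###EMISSIONS:<value>###" / "\n###TOKEN_COUNT:<value>###"
    # lines yielded above and removes them from the visible text.
    text = re.sub(r"\n###(EMISSIONS|TOKEN_COUNT):([^#]+)###", _capture, chunk)
    return text, meta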


# Passing both the model id and the model type is redundant. The call_llm
# interface could be refactored so that every model shares a unified
# interface, but that is not a priority.
def _call_gemini(
    model_id: str, msgs: list[dict], model_type: str
) -> tuple[str, float, int]:
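    """Call Gemini on the flattened transcript and return the reply text,
    the averaged GWP estimate, and the output token count."""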
    # The request takes a single string, so flatten the chat history into a
    # "ROLE: content" transcript.
    transcript = []
    for m in msgs:
        role = m["role"]
        content = m["content"]
        transcript.append(f"{role.upper()}: {content}")
    contents = "\n".join(transcript)
    temperature = 0.2 if model_type == "google-conservative" else 1.0
    resp = gemini_client.models.generate_content(
        model=model_id,
        contents=contents,
        config={"temperature": temperature},
    )
    output_token_count = (
        resp.usage_metadata.candidates_token_count
        if resp.usage_metadata is not None
        else 0
    )
    log_environment_event("inference", resp.impacts, model_type)  # pyright: ignore[reportAttributeAccessIssue]
    # Ecologits returns a range for Gemini; average the bounds to get a
    # single value.
    gwp_avg_value = (
        resp.impacts.usage.gwp.value.min + resp.impacts.usage.gwp.value.max  # pyright: ignore[reportAttributeAccessIssue]
    ) / 2
    return (resp.text or "").strip(), gwp_avg_value, output_token_count or 0


def _call_champ(
    lang: Literal["en", "fr"],
    conversation: List[ChatMessage],
    document_contents: List[str] | None,
    prompt_template: str | None = None,
) -> tuple[str, float, dict[str, Any], list[str], int]:
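    """Run the RAG pipeline via ChampService, tracing each stage, and return
    the reply, its estimated GWP, triage metadata, the retrieved context,
    and the output token count."""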
    tracer = trace.get_tracer(__name__)
    vector_store = _get_vector_store(document_contents)
    with tracer.start_as_current_span("ChampService"):
        champ = ChampService(
            vector_store=vector_store,
            lang=lang,
            model_type="champ",
            prompt_template=prompt_template,
        )
    with tracer.start_as_current_span("convert_messages_langchain"):
        msgs = convert_messages_langchain(conversation)
    with tracer.start_as_current_span("invoke"):
        reply, triage_meta, context, _ = champ.invoke(msgs)
    # LangChain is not compatible with Ecologits, so we approximate the
    # environmental impact from the output token count.
    encoding = tiktoken.get_encoding("o200k_harmony")
    final_token_count = len(encoding.encode(reply))
    champ_impacts = get_champ_impacts(final_token_count)
    log_environment_event("inference", champ_impacts, "champ")
    return (
        reply,
        champ_impacts.usage.gwp.value,  # pyright: ignore[reportReturnType]
        triage_meta,
        context,
        final_token_count,
    )


def _call_qwen(
    lang: Literal["en", "fr"],
    conversation: List[ChatMessage],
    document_contents: List[str] | None,
) -> tuple[str, float, dict[str, Any], list[str], int]:
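    """Run the RAG pipeline via ChampService with the Qwen backend and
    return the same tuple shape as _call_champ."""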
    vector_store = _get_vector_store(document_contents)
    champ = ChampService(vector_store=vector_store, lang=lang, model_type="qwen")
    msgs = convert_messages_qwen(conversation)
    reply, triage_meta, context, n_tokens = champ.invoke(msgs)
    # Ecologits does not support Qwen yet because the model is too recent;
    # support may be added to the library eventually.
    qwen_impacts = get_qwen_impacts(n_tokens)
    log_environment_event("inference", qwen_impacts, "qwen")
    return (
        reply,
        qwen_impacts.usage.gwp.value,  # pyright: ignore[reportReturnType]
        triage_meta,
        context,
        n_tokens,
    )


def call_llm(
    model_type: str,
    lang: Literal["en", "fr"],
    conversation: List[ChatMessage],
    document_contents: List[str] | None,
) -> AsyncGenerator[str, None] | Tuple[str, float, Dict[str, Any], List[str], int]:
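    """Dispatch to the backend selected by model_type. The "openai" path
    returns an async generator that streams text plus sentinel markers;
    every other path returns a (reply, gwp, triage_meta, context,
    token_count) tuple."""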
    if model_type not in MODEL_MAP:
        raise ValueError(f"Unknown model_type: {model_type}")
    if model_type == "champ":
        return _call_champ(lang, conversation, document_contents)
    elif model_type == "qwen":
        return _call_qwen(lang, conversation, document_contents)
    model_id = MODEL_MAP[model_type]
    msgs = convert_messages(conversation, lang=lang, docs_content=document_contents)
    if model_type == "openai":
        return _call_openai(model_id, msgs)
    if model_type in ["google-conservative", "google-creative"]:
        reply, gwp_emissions, output_token_count = _call_gemini(
            model_id, msgs, model_type
        )
        return reply, gwp_emissions, {}, [], output_token_count
    # If you later add HF models via hf_client, handle them here.
    raise ValueError(f"Unhandled model_type: {model_type}")
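

# Illustrative usage sketch, not part of the module API. It assumes that
# ChatMessage(role=..., content=...) is a valid constructor (check
# classes.base_models) and that both API keys are configured.
if __name__ == "__main__":
    import asyncio
    import inspect

    async def _demo() -> None:
        conversation = [ChatMessage(role="user", content="Hello!")]
        result = call_llm("openai", "en", conversation, None)
        if inspect.isasyncgen(result):
            # The OpenAI path streams text deltas plus the sentinel markers.
            async for piece in result:
                print(piece, end="", flush=True)
        else:
            # Every other path returns the five-element tuple.
            reply, gwp, _triage_meta, _context, n_tokens = result
            print(f"{reply}\n(gwp={gwp}, tokens={n_tokens})")

    asyncio.run(_demo())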