| from transformers import AutoTokenizer
|
| from huggingface_hub import InferenceClient
|
| from langchain.prompts import PromptTemplate
|
| from typing import List, Dict
|
| import google.generativeai as genai
|
| import os
|
| from dotenv import load_dotenv
|
| load_dotenv()
|
|
|
| def safe_print(*args, **kwargs):
|
| from pdf_to_vector_store import safe_print as sp
|
| sp(*args, **kwargs)
|
|
|
| class LLMInterface:
|
| def __init__(self, hf_api_token: str = None, gemini_api_key: str = None, nebius_api_key: str = None, fireworks_api_key: str = None):
|
| """
|
| Initializes the LLM interface with Gemini, Fireworks AI, and Nebius API models.
|
|
|
| Args:
|
| hf_api_token: Hugging Face API token for tokenizer
|
| gemini_api_key: Google Gemini API key for serverless inference
|
| nebius_api_key: Nebius API key for LLaMA models
|
| fireworks_api_key: Fireworks AI API key for Mistral models
|
| """
|
|
|
| self.hf_api_token = hf_api_token or os.getenv("HF_TOKEN")
|
| if not self.hf_api_token:
|
| raise ValueError("HF_TOKEN environment variable not set. Set HF_TOKEN or provide a token.")
|
|
|
|
|
| self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY")
|
| if not self.gemini_api_key:
|
| raise ValueError("GEMINI_API_KEY environment variable not set. Set GEMINI_API_KEY or provide a key.")
|
| genai.configure(api_key=self.gemini_api_key)
|
|
|
|
|
| self.nebius_api_key = nebius_api_key or os.getenv("NEBIUS_API_KEY")
|
| if not self.nebius_api_key:
|
| raise ValueError("NEBIUS_API_KEY environment variable not set. Set NEBIUS_API_KEY or provide a key.")
|
|
|
|
|
| self.fireworks_api_key = fireworks_api_key or os.getenv("FIREWORKS_API_KEY")
|
| if not self.fireworks_api_key:
|
| raise ValueError("FIREWORKS_API_KEY environment variable not set. Set FIREWORKS_API_KEY or provide a key.")
|
|
|
|
|
| try:
|
| self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-70B-Instruct", token=self.hf_api_token)
|
| except Exception as e:
|
| safe_print(f"Error loading tokenizer: {e}")
|
| raise
|
|
|
| self.max_tokens = 4096
|
|
|
|
|
| self.models = [
|
| {
|
| "name": "google/gemini-1.5-flash",
|
| "display_name": "Gemini",
|
| "client": genai.GenerativeModel("gemini-1.5-flash"),
|
| "type": "gemini"
|
| },
|
| {
|
| "name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
| "display_name": "Mistral",
|
| "client": InferenceClient(provider="nebius", api_key=self.fireworks_api_key),
|
| "type": "nebius"
|
| },
|
| {
|
| "name": "meta-llama/Llama-3.1-70B-Instruct",
|
| "display_name": "Llama",
|
| "client": InferenceClient(provider="nebius", api_key=self.nebius_api_key),
|
| "type": "nebius"
|
| }
|
| ]
|
|
|
| self.prompt_template = PromptTemplate(
|
| input_variables=["context", "question"],
|
| template="You are a helpful assistant specializing in workout, diet, and gym recommendations. Answer questions related to these topics in a conversational style, using the provided context if relevant. For questions unrelated to workout, diet, or gym, do not provide a direct answer; instead, respond conversationally, steering the conversation back to fitness, diet, or gym topics without explicitly stating the question is out of scope.\n\nContext: {context}\n\nUser: {question}\n\nAssistant:"
|
| )
|
|
|
| def truncate_to_max_tokens(self, text: str) -> str:
|
| """
|
| Truncates text to fit within the maximum token limit.
|
| """
|
| tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
| if len(tokens) > self.max_tokens:
|
| tokens = tokens[:self.max_tokens]
|
| text = self.tokenizer.decode(tokens, skip_special_tokens=True)
|
| return text |