ML-Chatbot / llm_interface.py
kmanche4675
feat: Finalize GPT-OSS architecture and add llm_interface to version control
3a7bb61
Raw
History Blame Contribute Delete
2.3 kB
import os
from openai import OpenAI
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
# Load settings from a local .env file (if one is found) into the process
# environment at import time, so the os.getenv() lookups in LLMProvider
# (ACTIVE_LLM_PROVIDER, OPENAI_API_KEY, HF_TOKEN) can see them.
load_dotenv()
class LLMProvider:
    """Send citation-enforcing question/context prompts to one chat backend.

    The provider name ``"openai"`` selects the OpenAI chat-completions
    client; every other name selects the Hugging Face ``InferenceClient``.
    Both backends receive the same system/user message pair and are sampled
    at the same temperature.
    """

    # Instruction sent as the system message and also repeated at the top of
    # the user message (the original prompt layout is kept as-is).
    _CITATION_INSTRUCTION = (
        "You MUST cite the specific sources from the context provided using their IDs in brackets, "
        "like [S12] or [PAPER_001]. If a paper has a filename, use that. "
        "Always provide a 'References' list at the end."
    )
    # Model identifier actually sent to the OpenAI API by generate().
    _OPENAI_ENGINE = "gpt-4o"
    # Completion length cap applied on the Hugging Face path only.
    _HF_MAX_TOKENS = 800
    # Sampling temperature shared by both backends.
    _TEMPERATURE = 0.2

    def __init__(self, provider=None):
        """Create the client for the selected backend.

        Args:
            provider: Backend name. When falsy, the ``ACTIVE_LLM_PROVIDER``
                environment variable is used, falling back to ``"llama"``.
                Matching is case-insensitive and ignores surrounding
                whitespace.
        """
        self.provider = self._resolve_provider(provider)
        if self.provider == "openai":
            print("🔗 Connecting directly to official OpenAI API...")
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            # Display alias only: generate() does not send this value to the
            # OpenAI API, it sends _OPENAI_ENGINE instead.
            # NOTE(review): the alias names a different model than the engine
            # that is actually called - confirm this labelling is intended.
            self.model_name = "gpt-oss-120b"
        else:
            print("🦙 Initializing Llama-3-70B via Hugging Face...")
            self.client = InferenceClient(api_key=os.getenv("HF_TOKEN"))
            # On this path model_name is the model id passed to the API.
            self.model_name = "meta-llama/Meta-Llama-3-70B-Instruct"

    @staticmethod
    def _resolve_provider(provider=None):
        """Return the normalized provider name (stripped, lower-cased).

        Falls back to ``ACTIVE_LLM_PROVIDER`` and then to ``"llama"`` when
        the candidate is missing, empty, or whitespace-only, so that
        ``self.provider`` is never blank in error messages.
        """
        candidate = provider or os.getenv("ACTIVE_LLM_PROVIDER") or "llama"
        # Stray whitespace (e.g. " openai" or "OpenAI\n") must not silently
        # route the request to the default backend.
        return candidate.strip().lower() or "llama"

    def generate(self, prompt, context):
        """Ask the active backend to answer ``prompt`` using ``context``.

        Args:
            prompt: The user's question; interpolated into the user message.
            context: Retrieved source material; interpolated into the user
                message ahead of the question.

        Returns:
            The text of the first completion choice (``""`` when the backend
            returns no text content), or a string of the form
            ``"Error using <provider>: <message>"`` if the request raised.
            Exceptions are reported through the return value, not raised.
        """
        instruction = self._CITATION_INSTRUCTION
        full_query = f"{instruction}\n\nContext: {context}\n\nQuestion: {prompt}"
        # Both backends take the same message list, so build it once.
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": full_query},
        ]
        try:
            if self.provider == "openai":
                response = self.client.chat.completions.create(
                    model=self._OPENAI_ENGINE,
                    messages=messages,
                    temperature=self._TEMPERATURE,
                )
            else:
                response = self.client.chat_completion(
                    messages=messages,
                    model=self.model_name,
                    max_tokens=self._HF_MAX_TOKENS,
                    temperature=self._TEMPERATURE,
                )
            # message.content may be None; callers expect a string.
            return response.choices[0].message.content or ""
        except Exception as e:
            # Deliberate catch-all: callers receive the failure as text.
            return f"Error using {self.provider}: {str(e)}"