NHZ commited on
Commit
c549079
·
verified ·
1 Parent(s): 05b86d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -1
app.py CHANGED
@@ -8,9 +8,39 @@ from langchain.vectorstores import FAISS
8
  from langchain.embeddings import HuggingFaceEmbeddings
9
  from langchain.chains import RetrievalQA
10
  from langchain.prompts import PromptTemplate
11
- from langchain.llms import GroqLLM
 
12
  import streamlit as st
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Initialize Groq API LLM
15
  llm = GroqLLM(api_key=os.getenv("GROQ_API_KEY"))
16
 
 
8
  from langchain.embeddings import HuggingFaceEmbeddings
9
  from langchain.chains import RetrievalQA
10
  from langchain.prompts import PromptTemplate
11
+ from langchain.llms.base import LLM
12
+ from typing import Optional, List, Mapping, Any
13
  import streamlit as st
14
 
15
+ # Custom wrapper for Groq API
16
+ class GroqLLM(LLM):
17
+ def __init__(self, api_key: str, model: str = "llama-3.3-70b-versatile"):
18
+ self.api_key = api_key
19
+ self.model = model
20
+
21
+ @property
22
+ def _llm_type(self) -> str:
23
+ return "groq"
24
+
25
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
26
+ import requests
27
+
28
+ headers = {"Authorization": f"Bearer {self.api_key}"}
29
+ json_data = {
30
+ "model": self.model,
31
+ "messages": [{"role": "user", "content": prompt}],
32
+ }
33
+
34
+ response = requests.post(
35
+ "https://api.groq.com/v1/chat/completions", headers=headers, json=json_data
36
+ )
37
+
38
+ if response.status_code != 200:
39
+ raise ValueError(f"Groq API call failed: {response.status_code}, {response.text}")
40
+
41
+ data = response.json()
42
+ return data["choices"][0]["message"]["content"]
43
+
44
  # Initialize Groq API LLM
45
  llm = GroqLLM(api_key=os.getenv("GROQ_API_KEY"))
46