TimoTM commited on
Commit
057791c
·
verified ·
1 Parent(s): 38954a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -6,16 +6,20 @@ from langchain_community.vectorstores import FAISS
6
  from langchain.chains import RetrievalQA
7
  from langchain.llms.base import LLM
8
  from transformers import pipeline
 
9
 
10
  # Wrapper-Klasse für das deutsche GPT-2 Modell
11
  class GermanGPT2(LLM):
12
- def __init__(self, max_new_tokens=128, temperature=0.7):
13
- self.pipeline = pipeline("text-generation", model="dbmdz/german-gpt2")
 
 
14
  self.max_new_tokens = max_new_tokens
15
  self.temperature = temperature
 
16
 
17
  def _call(self, prompt, stop=None):
18
- result = self.pipeline(prompt, max_length=self.max_new_tokens, do_sample=True, temperature=self.temperature)
19
  return result[0]["generated_text"]
20
 
21
  @property
@@ -26,7 +30,7 @@ class GermanGPT2(LLM):
26
  def _llm_type(self):
27
  return "custom_german_gpt2"
28
 
29
- # Lade und verarbeite das PDF beim Start
30
  loader = PyPDFLoader("TrendingMedia_ChatbotBasis_FINAL.pdf")
31
  documents = loader.load()
32
  splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
 
6
  from langchain.chains import RetrievalQA
7
  from langchain.llms.base import LLM
8
  from transformers import pipeline
9
+ from pydantic import PrivateAttr
10
 
11
  # Wrapper-Klasse für das deutsche GPT-2 Modell
12
  class GermanGPT2(LLM):
13
+ _pipeline: any = PrivateAttr() # privates Attribut, um die Pipeline zu speichern
14
+
15
+ def __init__(self, max_new_tokens=128, temperature=0.7, **kwargs):
16
+ super().__init__(**kwargs)
17
  self.max_new_tokens = max_new_tokens
18
  self.temperature = temperature
19
+ self._pipeline = pipeline("text-generation", model="dbmdz/german-gpt2")
20
 
21
  def _call(self, prompt, stop=None):
22
+ result = self._pipeline(prompt, max_length=self.max_new_tokens, do_sample=True, temperature=self.temperature)
23
  return result[0]["generated_text"]
24
 
25
  @property
 
30
  def _llm_type(self):
31
  return "custom_german_gpt2"
32
 
33
+ # PDF wird beim Start automatisch geladen und verarbeitet
34
  loader = PyPDFLoader("TrendingMedia_ChatbotBasis_FINAL.pdf")
35
  documents = loader.load()
36
  splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)