TrendingBot / app.py
TimoTM's picture
Update app.py
dc6cb5d verified
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from transformers import pipeline
from pydantic import PrivateAttr
# Wrapper-Klasse für das deutsche GPT-2 Modell
class GermanGPT2(LLM):
_pipeline: any = PrivateAttr()
_max_new_tokens: int = PrivateAttr()
_temperature: float = PrivateAttr()
def __init__(self, max_new_tokens=128, temperature=0.7, **kwargs):
super().__init__(**kwargs)
self._max_new_tokens = max_new_tokens
self._temperature = temperature
self._pipeline = pipeline("text-generation", model="dbmdz/german-gpt2")
def _call(self, prompt, stop=None):
# Nutze nun max_new_tokens anstatt max_length
result = self._pipeline(prompt, max_new_tokens=self._max_new_tokens, do_sample=True, temperature=self._temperature)
return result[0]["generated_text"]
@property
def _identifying_params(self):
return {"model": "dbmdz/german-gpt2"}
@property
def _llm_type(self):
return "custom_german_gpt2"
# PDF wird beim Start automatisch geladen und verarbeitet
loader = PyPDFLoader("TrendingMedia_ChatbotBasis_FINAL.pdf")
documents = loader.load()
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
texts = splitter.split_documents(documents)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = FAISS.from_documents(texts, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 2})
# Verwende den neuen GermanGPT2-Wrapper als LLM
llm = GermanGPT2(max_new_tokens=128, temperature=0.7)
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
def ask_question(user_input):
if user_input.lower() in ["start", "hallo", "hi", "hey"]:
return "👋 Willkommen bei Trending Media! Wie kann ich dir behilflich sein?"
response = qa_chain.run(user_input)
if response.strip() == "" or "I'm sorry" in response or len(response.split()) < 5:
if "kontakt" in user_input.lower() or "erreichen" in user_input.lower():
return (
"📬 Du kannst uns direkt über dieses Formular kontaktieren:\n\n"
"**Vorname & Nachname:**\n[_________]\n\n"
"**E-Mail:**\n[_________]\n\n"
"**Nachricht:**\n[__________________________]\n\n"
"*Oder direkt über:* [📨 Kontaktformular](https://trendingmedia.ch/kontakt)"
)
else:
return "❌ Das kann ich dir leider nicht beantworten. Ich bin auf Informationen aus unserem PDF spezialisiert."
return response
with gr.Blocks() as demo:
gr.Markdown("## 🤖 TrendingBot\nWillkommen bei Trending Media! Stelle mir deine Frage.")
user_input = gr.Textbox(label="Deine Frage", placeholder="Frag mich etwas über unsere Lösungen...")
bot_response = gr.Textbox(label="TrendingBot antwortet")
user_input.submit(ask_question, inputs=user_input, outputs=bot_response)
demo.launch()