Spaces:
Sleeping
Sleeping
File size: 3,245 Bytes
af2f254 46d395e af2f254 bce48eb 46d395e af2f254 bce48eb 46d395e 057791c 46d395e 38954a9 2129939 057791c 2129939 057791c 46d395e dc6cb5d 46d395e af2f254 46d395e 38954a9 bce48eb 38954a9 46d395e 057791c af2f254 46d395e 38954a9 d52a866 af2f254 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from transformers import pipeline
from pydantic import PrivateAttr
# Wrapper-Klasse für das deutsche GPT-2 Modell
class GermanGPT2(LLM):
_pipeline: any = PrivateAttr()
_max_new_tokens: int = PrivateAttr()
_temperature: float = PrivateAttr()
def __init__(self, max_new_tokens=128, temperature=0.7, **kwargs):
super().__init__(**kwargs)
self._max_new_tokens = max_new_tokens
self._temperature = temperature
self._pipeline = pipeline("text-generation", model="dbmdz/german-gpt2")
def _call(self, prompt, stop=None):
# Nutze nun max_new_tokens anstatt max_length
result = self._pipeline(prompt, max_new_tokens=self._max_new_tokens, do_sample=True, temperature=self._temperature)
return result[0]["generated_text"]
@property
def _identifying_params(self):
return {"model": "dbmdz/german-gpt2"}
@property
def _llm_type(self):
return "custom_german_gpt2"
# PDF wird beim Start automatisch geladen und verarbeitet
loader = PyPDFLoader("TrendingMedia_ChatbotBasis_FINAL.pdf")
documents = loader.load()
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
texts = splitter.split_documents(documents)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = FAISS.from_documents(texts, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 2})
# Verwende den neuen GermanGPT2-Wrapper als LLM
llm = GermanGPT2(max_new_tokens=128, temperature=0.7)
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
def ask_question(user_input):
if user_input.lower() in ["start", "hallo", "hi", "hey"]:
return "👋 Willkommen bei Trending Media! Wie kann ich dir behilflich sein?"
response = qa_chain.run(user_input)
if response.strip() == "" or "I'm sorry" in response or len(response.split()) < 5:
if "kontakt" in user_input.lower() or "erreichen" in user_input.lower():
return (
"📬 Du kannst uns direkt über dieses Formular kontaktieren:\n\n"
"**Vorname & Nachname:**\n[_________]\n\n"
"**E-Mail:**\n[_________]\n\n"
"**Nachricht:**\n[__________________________]\n\n"
"*Oder direkt über:* [📨 Kontaktformular](https://trendingmedia.ch/kontakt)"
)
else:
return "❌ Das kann ich dir leider nicht beantworten. Ich bin auf Informationen aus unserem PDF spezialisiert."
return response
with gr.Blocks() as demo:
gr.Markdown("## 🤖 TrendingBot\nWillkommen bei Trending Media! Stelle mir deine Frage.")
user_input = gr.Textbox(label="Deine Frage", placeholder="Frag mich etwas über unsere Lösungen...")
bot_response = gr.Textbox(label="TrendingBot antwortet")
user_input.submit(ask_question, inputs=user_input, outputs=bot_response)
demo.launch()
|