Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from langchain.llms.base import LLM | |
| from transformers import pipeline | |
| from pydantic import PrivateAttr | |
| # Wrapper-Klasse für das deutsche GPT-2 Modell | |
| class GermanGPT2(LLM): | |
| _pipeline: any = PrivateAttr() | |
| _max_new_tokens: int = PrivateAttr() | |
| _temperature: float = PrivateAttr() | |
| def __init__(self, max_new_tokens=128, temperature=0.7, **kwargs): | |
| super().__init__(**kwargs) | |
| self._max_new_tokens = max_new_tokens | |
| self._temperature = temperature | |
| self._pipeline = pipeline("text-generation", model="dbmdz/german-gpt2") | |
| def _call(self, prompt, stop=None): | |
| # Nutze nun max_new_tokens anstatt max_length | |
| result = self._pipeline(prompt, max_new_tokens=self._max_new_tokens, do_sample=True, temperature=self._temperature) | |
| return result[0]["generated_text"] | |
| def _identifying_params(self): | |
| return {"model": "dbmdz/german-gpt2"} | |
| def _llm_type(self): | |
| return "custom_german_gpt2" | |
| # PDF wird beim Start automatisch geladen und verarbeitet | |
| loader = PyPDFLoader("TrendingMedia_ChatbotBasis_FINAL.pdf") | |
| documents = loader.load() | |
| splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100) | |
| texts = splitter.split_documents(documents) | |
| embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") | |
| db = FAISS.from_documents(texts, embeddings) | |
| retriever = db.as_retriever(search_kwargs={"k": 2}) | |
| # Verwende den neuen GermanGPT2-Wrapper als LLM | |
| llm = GermanGPT2(max_new_tokens=128, temperature=0.7) | |
| qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever) | |
| def ask_question(user_input): | |
| if user_input.lower() in ["start", "hallo", "hi", "hey"]: | |
| return "👋 Willkommen bei Trending Media! Wie kann ich dir behilflich sein?" | |
| response = qa_chain.run(user_input) | |
| if response.strip() == "" or "I'm sorry" in response or len(response.split()) < 5: | |
| if "kontakt" in user_input.lower() or "erreichen" in user_input.lower(): | |
| return ( | |
| "📬 Du kannst uns direkt über dieses Formular kontaktieren:\n\n" | |
| "**Vorname & Nachname:**\n[_________]\n\n" | |
| "**E-Mail:**\n[_________]\n\n" | |
| "**Nachricht:**\n[__________________________]\n\n" | |
| "*Oder direkt über:* [📨 Kontaktformular](https://trendingmedia.ch/kontakt)" | |
| ) | |
| else: | |
| return "❌ Das kann ich dir leider nicht beantworten. Ich bin auf Informationen aus unserem PDF spezialisiert." | |
| return response | |
| with gr.Blocks() as demo: | |
| gr.Markdown("## 🤖 TrendingBot\nWillkommen bei Trending Media! Stelle mir deine Frage.") | |
| user_input = gr.Textbox(label="Deine Frage", placeholder="Frag mich etwas über unsere Lösungen...") | |
| bot_response = gr.Textbox(label="TrendingBot antwortet") | |
| user_input.submit(ask_question, inputs=user_input, outputs=bot_response) | |
| demo.launch() | |