# app.py — LLM_test Hugging Face Space (author: AssanaliAidarkhan, commit fc5454c)
# Ultra-Simple Qwen RAG - Guaranteed to work!
import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import logging
# Module-wide logging at INFO so model-download/loading progress is visible
# in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Simple knowledge base: a tiny in-memory corpus used as the retrieval
# source. Each entry is one self-contained passage that
# SimpleQwenRAG.find_best_context can return verbatim as the "Source".
KNOWLEDGE = [
"Artificial Intelligence (AI) is intelligence demonstrated by machines. AI systems can perform tasks that require human intelligence like learning, reasoning, and problem-solving.",
"Machine Learning is a subset of AI that enables computers to learn from data without explicit programming. Main types include supervised, unsupervised, and reinforcement learning.",
"Deep Learning uses neural networks with multiple layers to process complex data. It's especially effective for image recognition, natural language processing, and speech recognition.",
"Natural Language Processing (NLP) helps computers understand human language. Applications include chatbots, translation, sentiment analysis, and question-answering systems.",
"Computer Vision enables computers to interpret visual information. Applications include facial recognition, medical imaging, autonomous vehicles, and security systems."
]
class SimpleQwenRAG:
    """Minimal retrieval-augmented generation (RAG) pipeline.

    Retrieval: cosine similarity between MiniLM sentence embeddings of the
    question and the KNOWLEDGE passages. Generation: Qwen1.5-0.5B-Chat
    completes a short prompt built from the question and retrieved passage.
    """

    def __init__(self):
        """Load the embedding and generation models and pre-embed KNOWLEDGE."""
        logger.info("Loading models...")
        # Sentence-embedding model used for retrieval.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Small chat-tuned LLM; float32 keeps it CPU-friendly.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat",
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat",
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            # Some causal-LM tokenizers ship without a pad token; reuse EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Pre-compute knowledge embeddings once. L2-normalizing here (and in
        # find_best_context) makes the plain dot product a true cosine
        # similarity; without it the score is biased by embedding magnitude.
        self.embeddings = self.embedding_model.encode(
            KNOWLEDGE, normalize_embeddings=True
        )
        logger.info("βœ… Ready!")

    def find_best_context(self, question):
        """Return the KNOWLEDGE passage most similar to *question*."""
        question_emb = self.embedding_model.encode(
            [question], normalize_embeddings=True
        )
        # Both sides are unit-normalized, so dot product == cosine similarity.
        similarities = np.dot(self.embeddings, question_emb.T).flatten()
        return KNOWLEDGE[int(np.argmax(similarities))]

    def generate_answer(self, question, context):
        """Generate a short answer with Qwen; fall back to the context text.

        Returns the decoded model continuation, or the first two sentences of
        *context* when generation comes back empty/too short.
        """
        # Very simple prompt: no chat template, plain completion.
        prompt = f"Question: {question}\nInformation: {context}\nAnswer:"
        inputs = self.tokenizer(
            prompt, return_tensors="pt", max_length=300, truncation=True
        )
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                # Explicit mask: required for correct behavior when
                # pad_token == eos_token, and silences the HF warning.
                attention_mask=inputs.attention_mask,
                max_new_tokens=80,
                temperature=0.5,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = outputs[0][inputs.input_ids.shape[1]:]
        answer = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        if len(answer) < 5:
            # Degenerate generation: use the first two context sentences.
            sentences = context.split('.')[:2]
            answer = '. '.join(sentences) + '.'
        return answer

    def ask(self, question):
        """Answer *question*: retrieve the best context, generate, format."""
        if not question.strip():
            return "Please ask a question!"
        context = self.find_best_context(question)
        answer = self.generate_answer(question, context)
        return f"**Answer:** {answer}\n\n**Source:** {context}"
# Initialize the RAG system once at import time. The UI degrades gracefully
# if model loading fails (no network, out of memory, ...): the status box
# shows the error and the callback returns a fixed message.
try:
    rag = SimpleQwenRAG()
    status = "βœ… Qwen RAG Ready!"
except Exception as e:
    # Keep the full traceback in the Space logs instead of swallowing it;
    # only the short message is surfaced to the UI.
    logging.exception("RAG initialization failed")
    rag = None
    status = f"❌ Failed: {e}"

def process_question(question):
    """Gradio callback: route a user question through the RAG pipeline."""
    if rag is None:
        return "System not ready!"
    return rag.ask(question)
# --- Gradio front-end ---------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# πŸ€– Simple Qwen 1.5 RAG")
    status_box = gr.Textbox(label="Status", value=status, interactive=False)

    with gr.Row():
        question_input = gr.Textbox(label="Question", placeholder="What is AI?")
        ask_btn = gr.Button("Ask", variant="primary")

    answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)

    # One shortcut button per canned question. The default-argument trick
    # (s=sample) binds each button to its own text despite the loop.
    samples = ["What is AI?", "How does machine learning work?", "What is deep learning?"]
    for sample in samples:
        gr.Button(sample).click(lambda s=sample: s, outputs=[question_input])

    # Both the button and pressing Enter in the question box trigger the
    # same callback.
    ask_btn.click(process_question, [question_input], [answer_output])
    question_input.submit(process_question, [question_input], [answer_output])

demo.launch()