# base-chatbot / app.py — Hugging Face Space entry point (commit 682cf73)
import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
import easyocr
import glob
import re
import os
import base64
from io import BytesIO
import math
# Global storage (lazily populated module-level state)
pdf_texts = {}  # pdf file path -> extracted text, filled by load_pdfs()
reader = None  # easyocr.Reader, created on first use by initialize_vision_models()
blip_processor = None  # BLIP processor, or the string "error" if loading failed
blip_model = None  # BLIP captioning model, or the string "error" if loading failed
def load_pdfs():
    """Scan known directories for PDFs and cache their extracted text.

    Populates the module-level ``pdf_texts`` dict (path -> text) and
    returns a short status string for display in the UI.
    """
    global pdf_texts
    pdf_texts.clear()
    pdf_paths = glob.glob("*.pdf") + glob.glob("sources/*.pdf") + glob.glob("data/*.pdf")
    for pdf_path in pdf_paths:
        try:
            # Imported lazily so the app still starts if pdfplumber is absent.
            import pdfplumber
            with pdfplumber.open(pdf_path) as pdf:
                text = "\n".join(page.extract_text() or "" for page in pdf.pages)
            pdf_texts[pdf_path] = text
        except Exception:
            # Best-effort: skip unreadable PDFs instead of crashing startup.
            continue
    return f"✅ {len(pdf_texts)} PDFs" if pdf_texts else "No PDFs"
def initialize_vision_models():
    """Lazily load the OCR reader and BLIP captioning model (at most once).

    On BLIP load failure, sets the globals to the sentinel string
    ``"error"`` so analyze_image() can skip captioning without retrying.
    """
    global reader, blip_processor, blip_model
    if reader is None:
        reader = easyocr.Reader(['en'], gpu=False)
    if blip_processor is None:
        try:
            from transformers import BlipProcessor, BlipForConditionalGeneration
            blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        except Exception:
            # Narrowed from a bare except; captioning becomes a no-op.
            blip_processor = "error"
            blip_model = "error"
def enhance_image_for_ocr(image):
    """Upscale, grayscale, and sharpen *image* to improve OCR accuracy.

    Returns a new RGB PIL image; the input is not modified in place.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    w, h = image.size
    # 4x upscale so small glyphs survive the later filtering steps.
    upscaled = image.resize((w * 4, h * 4), Image.Resampling.LANCZOS)
    gray = upscaled.convert('L')
    boosted = ImageEnhance.Contrast(gray).enhance(2.0)
    boosted = ImageEnhance.Sharpness(boosted).enhance(2.0)
    # Median filter removes salt-and-pepper noise introduced by upscaling.
    denoised = boosted.filter(ImageFilter.MedianFilter(size=3))
    return denoised.convert('RGB')
def decode_base64_image(image_data):
    """Decode a ``data:image/...;base64,`` URL into a PIL image.

    Passes through ``None`` and already-decoded PIL images unchanged.
    Any other value — including a data URL that fails to decode — is
    returned as-is so callers can decide how to handle it.
    """
    if image_data is None:
        return None
    if isinstance(image_data, Image.Image):
        return image_data
    if isinstance(image_data, str) and image_data.startswith('data:image'):
        try:
            # Strip the "data:image/...;base64," prefix before decoding.
            # Decoded into a local so a failure here does not clobber the
            # original argument (the old code returned the half-stripped
            # payload on failure).
            payload = image_data.split(',', 1)[1]
            return Image.open(BytesIO(base64.b64decode(payload)))
        except Exception:
            pass
    return image_data
def web_search(query):
    """Return up to two DuckDuckGo result snippets for *query*.

    Best-effort helper: returns ``None`` on any failure (missing ddgs
    package, network error, unexpected result shape), which callers
    treat as "no web context available".
    """
    try:
        # Imported lazily: ddgs is optional and may not be installed.
        from ddgs import DDGS
        results = DDGS().text(query, max_results=2)
        return "\n".join(f"{r['title']}: {r['body'][:100]}" for r in results)
    except Exception:
        # Narrowed from a bare except; still deliberately broad because
        # any failure mode here should degrade to "no results".
        return None
def analyze_image(image):
    """Run OCR and (when available) BLIP captioning on a PIL image.

    Returns ``(ocr_text, caption)``; on any failure returns
    ``("", error_message)``.
    """
    initialize_vision_models()
    try:
        prepped = np.array(enhance_image_for_ocr(image))
        detections = reader.readtext(
            prepped, detail=1, paragraph=False,
            min_size=1, text_threshold=0.6, low_text=0.3,
        )
        # Keep only reasonably confident detections.
        extracted = ' '.join(txt for _, txt, conf in detections if conf > 0.3)
        if blip_processor == "error":
            caption = "Image"
        else:
            inputs = blip_processor(image, return_tensors="pt")
            generated = blip_model.generate(**inputs, max_length=50)
            caption = blip_processor.decode(generated[0], skip_special_tokens=True)
        return extracted, caption
    except Exception as exc:
        return "", str(exc)
def extract_math_calcs(text):
    """Find ``C(n, k)`` combination expressions in *text* and evaluate them.

    Returns a list of strings like ``"C(5,2)=10"`` (results use thousands
    separators), one per match, in order of appearance.
    """
    pattern = re.compile(r'C\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)')
    return [
        f"C({n},{k})={math.comb(n, k):,}"
        for n, k in ((int(a), int(b)) for a, b in pattern.findall(text))
    ]
def get_pdf_context(query):
    """Return the best-matching PDF sentence fragment for *query*, if any.

    Scores the first 40 sentences of each cached PDF by overlap with the
    query's keywords (words of 4+ characters, case-insensitive).  Returns
    the top-scoring 150-char fragment only when at least two keywords
    match; otherwise None.
    """
    if not pdf_texts:
        return None
    keywords = set(re.findall(r'\b\w{4,}\b', query.lower()))
    scored = []
    for text in pdf_texts.values():
        for sentence in text.split('.')[:40]:
            lowered = sentence.lower()
            hits = sum(1 for kw in keywords if kw in lowered)
            if hits:
                scored.append((hits, sentence[:150]))
    if not scored:
        return None
    # Largest (score, fragment) tuple — same winner as sorting descending.
    top_score, top_fragment = max(scored)
    return top_fragment if top_score >= 2 else None
# MAIN API FUNCTION - Simple interface for external calls
def api_analyze(message: str, image_base64: str = None, model: str = "Qwen/Qwen2.5-7B-Instruct"):
    """Answer *message* (optionally with a base64 screenshot) via the HF Inference API.

    Detects MCQ-style questions (options labelled A-D) and switches to a
    terse exam prompt with low temperature and a small token budget.
    Returns the model's reply as a string, or an error message on failure.
    """
    token = os.getenv('HF_TOKEN')

    # Decode the optional screenshot into a PIL image.
    img = decode_base64_image(image_base64) if image_base64 else None

    # Heuristic MCQ detection: an option marker like "A. " or "B) ".
    has_options = bool(re.search(r'[A-D][\.\)]\s', message))

    # OCR the image (if any) and append the text as extra context.
    context = ""
    if img:
        try:
            ocr_text, _ = analyze_image(img)
            if ocr_text:
                context = f"\n\nExtracted text from image:\n{ocr_text[:400]}"
        except Exception as e:
            context = f"\n\n(Image processing error: {str(e)})"

    # MCQ answers are kept short and deterministic; general chat is freer.
    if has_options:
        sys_msg = "You are an exam assistant. For MCQ, give: Answer: [letter]. Reason: [one sentence only]."
        temp = 0.2
        tokens = 100
    else:
        sys_msg = "You are a helpful AI assistant."
        temp = 0.6
        tokens = 400

    try:
        client = InferenceClient(token=token, model=model)
    except Exception:
        # Fall back to the default model if the requested one is unavailable.
        try:
            client = InferenceClient(token=token, model="Qwen/Qwen2.5-7B-Instruct")
        except Exception as e:
            return f"Error connecting to model: {str(e)}"

    messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": message + context}
    ]

    # Stream the completion and accumulate it into a single reply string.
    response = ""
    try:
        for msg in client.chat_completion(
            messages,
            max_tokens=tokens,
            stream=True,
            temperature=temp,
            top_p=0.9
        ):
            if msg.choices and msg.choices[0].delta.content:
                response += msg.choices[0].delta.content
            # Stop early for MCQ to keep answers terse.
            if has_options and len(response) > 200:
                break
    except Exception as e:
        return f"Error during inference: {str(e)}"

    return response.strip()
# Chat function for UI
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    model_selection,
    image,
    hf_token,
):
    """Streaming chat handler for the Gradio UI.

    Yields the growing assistant reply as chunks stream in.  MCQ-style
    questions override the system prompt, temperature, and token budget;
    calculation questions are routed to the math-tuned model.
    """
    if image is not None:
        image = decode_base64_image(image)

    # Prefer the server-side token; fall back to the user-supplied one.
    token = os.getenv('HF_TOKEN') or (hf_token.strip() if hf_token else None)

    has_options = bool(re.search(r'[A-D][\.\)]\s', message))
    is_math_calc = any(w in message.lower() for w in ['calculate', 'factorial', 'combination'])

    # Route pure-calculation questions to the math-specialized model.
    if is_math_calc and not has_options:
        selected_model = "Qwen/Qwen2.5-Math-7B-Instruct"
    else:
        selected_model = model_selection

    try:
        client = InferenceClient(token=token, model=selected_model)
    except Exception:
        # Fall back to the default model before giving up entirely.
        try:
            client = InferenceClient(token=token, model="Qwen/Qwen2.5-7B-Instruct")
        except Exception:
            yield "❌ Cannot connect to model"
            return

    # Best-effort OCR context from an attached screenshot.
    context = ""
    if image is not None:
        try:
            ocr_text, _ = analyze_image(image)
            if ocr_text:
                context = f"\n\nImage text: {ocr_text[:400]}"
        except Exception:
            # Image analysis is optional; answer without it on failure.
            pass

    # MCQ: force a terse, low-temperature answer format.
    if has_options:
        system_message = "Exam assistant. MCQ format: Answer: [letter]. Reason: [one sentence]."
        temperature = 0.2
        max_tokens = 100

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message + context}
    ]

    response = ""
    try:
        for msg in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
            if msg.choices and msg.choices[0].delta.content:
                response += msg.choices[0].delta.content
                yield response
    except Exception as e:
        yield f"Error: {str(e)}"
# Load PDFs at import time so the status appears in the chat description.
pdf_status = load_pdfs()
# Create TWO separate interfaces
# 1. Chat UI for users
chat_interface = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
        gr.Slider(1, 800, 400, step=1, label="Max tokens"),
        gr.Slider(0.1, 1.2, 0.6, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            choices=["Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "HuggingFaceH4/zephyr-7b-beta","openai/gpt-oss-20b","Qwen/Qwen2.5-Math-7B-Instruct"],
            value="Qwen/Qwen2.5-7B-Instruct",
            label="Model",
        ),
        gr.Image(label="Screenshot", type="pil"),
        gr.Textbox(label="HF Token", type="password", value=""),
    ],
    title="🤖 Exam Helper",
    description=f"MCQ (short) • Math (steps) • General (detailed)\n\n{pdf_status}",
)
# 2. Simple API interface for programmatic (non-chat) access
api_interface = gr.Interface(
    fn=api_analyze,
    inputs=[
        gr.Textbox(label="Message", placeholder="Enter your question"),
        gr.Textbox(label="Image (base64)", placeholder="Optional base64 image"),
        gr.Textbox(label="Model", value="Qwen/Qwen2.5-7B-Instruct"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="API Endpoint",
    description="Direct API access",
    api_name="analyze"  # Creates /call/analyze endpoint
)
# Combine both in tabs
demo = gr.TabbedInterface(
    [chat_interface, api_interface],
    ["Chat", "API"],
    title="🤖 AI Assistant"
)
if __name__ == "__main__":
    demo.launch()