# STASS / app.py — AI study-material generator (Hugging Face Space entry point).
from docx import Document
import pytesseract
from PIL import Image
import fitz
import gradio as gr
import threading
import pathlib
import os
# --------------------------------------------------
# TOKEN RESOLUTION
# --------------------------------------------------
def resolve_token(ui_token):
    """Resolve the Hugging Face access token.

    Precedence: a non-blank token typed in the UI wins, then the
    ``hgface_tok`` environment variable, else the empty string
    (anonymous access).
    """
    # Guard against None: a cleared gr.Textbox can deliver None, and the
    # original `ui_token.strip()` would raise AttributeError.
    if ui_token and ui_token.strip():
        return ui_token.strip()
    env_token = os.getenv("hgface_tok")
    if env_token:
        return env_token.strip()
    return ""
# --------------------------------------------------
# FILE TEXT EXTRACTION
# --------------------------------------------------
# File extensions this app can extract text from.
# NOTE(review): currently informational only — extract_text_from_file
# branches on literal extensions rather than consulting this tuple.
SUPPORTED_EXT = (
    ".pdf", ".docx", ".txt", ".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"
)
def extract_text_from_file(filepath):
    """Extract plain text from a PDF, DOCX, TXT, or image file.

    Accepts either a path string or a file-like object exposing ``.name``
    (as gradio's File component provides). Returns the extracted text, or
    a human-readable error string on failure — this function never raises.
    """
    if not filepath:
        return ""
    # gradio may hand us an uploaded-file object rather than a raw path.
    if hasattr(filepath, "name"):
        filepath = filepath.name
    ext = pathlib.Path(filepath).suffix.lower()
    try:
        if ext == ".pdf":
            doc = fitz.open(filepath)
            try:
                return "\n".join(page.get_text() for page in doc)
            finally:
                # Close the PyMuPDF document explicitly — the original
                # leaked the file handle on every PDF upload.
                doc.close()
        elif ext == ".docx":
            document = Document(filepath)
            return "\n".join(p.text for p in document.paragraphs)
        elif ext == ".txt":
            with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        elif ext in (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"):
            try:
                img = Image.open(filepath)
                return pytesseract.image_to_string(img)
            except Exception as e:
                # OCR gets its own message so the user can tell a missing
                # tesseract binary apart from a generally unreadable file.
                return "OCR failed: " + str(e)
        else:
            return "Unsupported file type: " + ext
    except Exception as e:
        return "Could not read file: " + str(e)
# --------------------------------------------------
# MODELS
# --------------------------------------------------
# UI label -> Hugging Face model id. Labels carry approximate size and
# guidance text shown verbatim in the dropdown.
# NOTE(review): several ids (the "Qwen3.5" family, "Mistral-Small-3-Instruct")
# don't match obvious hub repositories — verify each id resolves on the Hub
# before shipping; a bad id only fails at first selection, inside get_pipeline.
MODELS = {
    "Gemma 3 270M [0.6GB | Lightning-fast Edge]": "google/gemma-3-270m-it",
    "Qwen 3 0.6B GGUF [0.5GB | Classroom Assistant]": "Qwen/Qwen3-0.6B-GGUF",
    "TinyLlama 1.1B [0.5GB]": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen 3.5 2B [2.4GB | The Student Tutor]": "Qwen/Qwen3.5-2B",
    "Phi-4 Mini [1.8GB | Logical Powerhouse]": "microsoft/Phi-4-mini-instruct",
    "Gemma 3 1B [2.1GB | Stable & Coherent]": "google/gemma-3-1b-it",
    "Qwen 3.5 9B [7.8GB | BEST FOR LESSON PLANS]": "Qwen/Qwen3.5-9B",
    "Llama 3.1 8B [5.2GB | Industry Standard]": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral Small 3 [7.1GB | Concise & Accurate]": "mistralai/Mistral-Small-3-Instruct",
    "Gemma 3 9B [6.3GB | Creative & Safe]": "google/gemma-3-9b-it",
    "Mistral Small 12B [9.5GB | Perfect VRAM Balance]": "mistralai/Mistral-Nemo-Instruct-2407",
    "Qwen 3.5 27B [18GB | Dense Curriculum Architect]": "Qwen/Qwen3.5-27B",
}
# Dropdown choices, in declaration order (first entry is the default).
ALL_MODEL_NAMES = list(MODELS.keys())
# --------------------------------------------------
# PIPELINE CACHE
# --------------------------------------------------
# Cache of constructed text-generation pipelines keyed by model id, so each
# model is loaded at most once per process.
_pipeline_cache = {}
# Serializes cache access; gradio may run event handlers concurrently.
_pipeline_lock = threading.Lock()
def get_pipeline(model_id, hf_token):
    """Return ``(pipeline, error)`` for *model_id*, loading and caching on first use.

    Exactly one element of the pair is meaningful: on success the cached
    pipeline and ``None``; on a load failure ``(None, message)``. Failures
    are not cached, so a later call retries the load. The lock is held for
    the whole load, so concurrent requests never build the same model twice
    (at the cost of serializing all loads).
    """
    # Imported lazily so the UI can start before transformers is pulled in.
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

    with _pipeline_lock:
        cached = _pipeline_cache.get(model_id)
        if cached is not None:
            return cached, None
        try:
            load_kwargs = {"trust_remote_code": True}
            if hf_token:
                load_kwargs["token"] = hf_token
            tok = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
            mdl = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="cpu",
                **load_kwargs,
            )
            pipe = pipeline(
                "text-generation",
                model=mdl,
                tokenizer=tok,
            )
        except Exception as exc:
            return None, str(exc)
        _pipeline_cache[model_id] = pipe
        return pipe, None
# --------------------------------------------------
# INFERENCE
# --------------------------------------------------
# System preamble prepended to every prompt sent to the model.
SYSTEM_MSG = "You are an expert educational assistant. Use markdown."
def ask_llm(model_label, prompt, hf_token=""):
    """Run one prompt through the model selected by *model_label*.

    Returns the generated markdown text; any load or inference failure is
    reported as an error string rather than raised, so the UI always has
    something to render.
    """
    token = resolve_token(hf_token)
    pipe, err = get_pipeline(MODELS[model_label], token)
    if err:
        return "Model load error:\n" + err
    try:
        full_prompt = SYSTEM_MSG + "\n\n" + prompt
        generated = pipe(
            full_prompt,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )[0]["generated_text"]
        # text-generation pipelines echo the prompt; drop it when present.
        if generated.startswith(full_prompt):
            generated = generated[len(full_prompt):]
        return generated.strip()
    except Exception as e:
        return "Inference error:\n" + str(e)
# --------------------------------------------------
# PROMPTS
# --------------------------------------------------
def make_prompts(topic):
    """Build the five generation prompts for *topic*, keyed by content type."""
    templates = {
        "lesson": "Create a lesson plan with headings and bullet points.",
        "qa": "Generate 10 exam questions with answers.",
        "mcq": "Generate 10 MCQs with 4 options and answers.",
        "summary": "Summarize the topic in 250-300 words.",
        "infographic": "Create a cheat sheet using tables and bullet points.",
    }
    # Every prompt ends with the same "Topic:" trailer carrying the syllabus.
    return {
        key: template + "\n\nTopic:\n" + topic
        for key, template in templates.items()
    }
def generate_content(text, file, model_label, token):
    """Stream the five study-material outputs, filling one tab at a time.

    Generator yielding 5-tuples (lesson, q&a, mcq, summary, cheat sheet).
    The first yield shows "Generating..." placeholders; each subsequent
    yield replaces one placeholder with real model output, so gradio can
    update the tabs progressively.
    """
    file_text = extract_text_from_file(file) if file else ""
    syllabus = (text + "\n\n" + file_text).strip()
    if not syllabus:
        yield ("Provide topic or file", "", "", "", "")
        return

    prompts = make_prompts(syllabus)
    placeholder = "Generating..."
    outputs = [placeholder] * 5
    yield tuple(outputs)

    for slot, key in enumerate(("lesson", "qa", "mcq", "summary", "infographic")):
        outputs[slot] = ask_llm(model_label, prompts[key], token)
        yield tuple(outputs)
# --------------------------------------------------
# UI
# --------------------------------------------------
# Custom stylesheet passed to the gradio app (forces the Inter font family).
CSS = """
body,.gradio-container{
font-family:Inter,sans-serif!important;
}
"""
# Build and launch the UI. theme= and css= are gr.Blocks() constructor
# arguments, not launch() arguments — the original passed them to
# demo.launch(), where Gradio rejects or ignores them, so the custom
# theme/font never applied.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="purple",
    ),
    css=CSS,
) as demo:
    gr.Markdown("# 🎓 AI Study Material Generator")
    with gr.Row():
        with gr.Column():
            # Free-text syllabus/topic entry; combined with any uploaded file.
            text_input = gr.Textbox(
                placeholder="Paste syllabus or topic",
                lines=6,
            )
            file_input = gr.File(
                label="Upload syllabus file",
            )
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES,
                value=ALL_MODEL_NAMES[0],
                label="Model",
            )
            token_box = gr.Textbox(
                label="HF Token (optional)",
                type="password",
            )
    btn = gr.Button("Generate")
    with gr.Tabs():
        with gr.TabItem("Lesson"):
            lesson = gr.Markdown()
        with gr.TabItem("Q&A"):
            qa = gr.Markdown()
        with gr.TabItem("MCQ"):
            mcq = gr.Markdown()
        with gr.TabItem("Summary"):
            summary = gr.Markdown()
        with gr.TabItem("Cheat Sheet"):
            cheat = gr.Markdown()
    # generate_content is a generator, so the five tabs stream in one by one.
    btn.click(
        fn=generate_content,
        inputs=[text_input, file_input, model_selector, token_box],
        outputs=[lesson, qa, mcq, summary, cheat],
    )

demo.launch()