Model_1 / app.py
Omnia-cy's picture
Create app.py
2b4174d verified
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification, pipeline
# =========================
# Load BERT Model
# =========================
# Hub repository holding both the tokenizer files and the classifier weights.
model_name = "Omnia-cy/bert_model_1"

tokenizer = BertTokenizer.from_pretrained(model_name)

# nn.Module.eval() returns the module itself, so the model can be loaded and
# switched to inference mode in a single expression.
model = BertForSequenceClassification.from_pretrained(model_name).eval()
# =========================
# Load Zero-Shot Model
# =========================
# Zero-shot classification pipeline backed by facebook/bart-large-mnli.
zero_shot = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Candidate labels handed to the zero-shot pipeline; it ranks them and the
# top-ranked one is reported to the user verbatim.
labels = [
    "This problem can be solved using software or AI",
    "This problem cannot be solved using software or AI",
]
# =========================
# Prediction Function
# =========================
def analyze_problem(sector, subsector, ptype, target, description):
    """Report whether a described problem looks solvable with software or AI.

    The five fields are merged into one prompt and scored by the BERT
    classifier. If the probability of class index 1 reaches the threshold,
    that verdict is returned directly; otherwise the zero-shot pipeline ranks
    the module-level ``labels`` and its top label is reported instead.

    Args:
        sector: Sector text entered in the UI.
        subsector: Subsector text entered in the UI.
        ptype: Problem type text entered in the UI.
        target: Target group text entered in the UI.
        description: Free-text description of the problem.

    Returns:
        A display string. For a missing or blank description this is a short
        prompt asking the user to provide one, and no model is invoked.
    """
    # Without a description the classifiers would only see the field
    # captions, so any verdict would be meaningless; bail out early instead.
    if description is None or not str(description).strip():
        return "⚠️ Please enter a problem description."

    # Same format as the notebook (keep it in sync with the training prompt).
    text = (
        f"Sector: {sector}. "
        f"Subsector: {subsector}. "
        f"Type: {ptype}. "
        f"Target Group: {target}. "
        f"Description: {description}."
    )

    # -------- BERT --------
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=1)
    # NOTE(review): class index 1 is treated as "solvable using AI" —
    # confirm this matches the label order used when the model was trained.
    ai_prob = probs[0][1].item()

    # -------- Decision --------
    threshold = 0.8
    if ai_prob >= threshold:
        return f"✅ AI Probability: {ai_prob:.2f}\nResult: Solvable using AI"

    # -------- Zero-Shot --------
    # BERT was not confident enough, so defer to the zero-shot ranking.
    result = zero_shot(text, candidate_labels=labels)
    final_label = result["labels"][0]
    return (
        f"\nAI Probability: {ai_prob:.2f}\n"
        "Zero-Shot Result:\n"
        f"{final_label}\n"
    )
# =========================
# UI
# =========================
# One free-text box per positional argument of analyze_problem, in order.
iface = gr.Interface(
    fn=analyze_problem,
    inputs=[
        gr.Textbox(label=caption)
        for caption in (
            "Sector",
            "Subsector",
            "Type",
            "Target Group",
            "Problem Description",
        )
    ],
    outputs=gr.Textbox(label="Result"),
    title="AI Problem Analyzer",
    description="Check if a problem can be solved using AI or software",
)

# Start the web server as soon as the script is executed.
iface.launch()