sivan26 commited on
Commit
65f4c3e
·
verified ·
1 Parent(s): c211272

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -3,16 +3,15 @@ import requests
3
  import random
4
  from transformers import pipeline
5
 
6
- # Load the zero-shot classification model
7
  classifier = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-3")
8
 
9
- # Simplified general labels for better accuracy
10
  labels = [
11
  "animals", "people", "places", "history", "science", "art", "technology",
12
  "sports", "food", "clothing", "home", "entertainment", "education", "nature", "transportation"
13
  ]
14
 
15
- # Preprocessing for better classification context
16
  def preprocess_topic(topic):
17
  topic = topic.lower().strip()
18
  mapping = {
@@ -35,13 +34,12 @@ def preprocess_topic(topic):
35
  }
36
  return mapping.get(topic, topic)
37
 
38
- # Random topics list
39
  random_topics = [
40
  "cats", "space", "chocolate", "Egypt", "Leonardo da Vinci",
41
  "volcanoes", "Tokyo", "honeybees", "quantum physics", "orcas"
42
  ]
43
 
44
- # Wikipedia + classification function
45
  def get_wikipedia_facts(topic):
46
  if not topic.strip():
47
  return "Please enter a topic or use 'Surprise me!'", None, None
@@ -80,7 +78,7 @@ def get_wikipedia_facts(topic):
80
  facts = [fact if fact.endswith(".") else fact + "." for fact in facts]
81
  facts_text = "\n\n".join(f"💡 {fact}" for fact in facts)
82
 
83
- # Zero-shot classification with general labels
84
  processed_input = preprocess_topic(topic)
85
  classification = classifier(processed_input, candidate_labels=labels)
86
  top_labels = classification["labels"][:3]
@@ -98,12 +96,12 @@ def get_wikipedia_facts(topic):
98
  print("Error:", e)
99
  return "Oops! Something went wrong while fetching your facts.", None, None
100
 
101
- # Surprise me button
102
  def surprise_topic(_):
103
  topic = random.choice(random_topics)
104
  return get_wikipedia_facts(topic)
105
 
106
- # Gradio UI
107
  with gr.Blocks() as demo:
108
  gr.HTML("""
109
  <style>
 
3
  import random
4
  from transformers import pipeline
5
 
6
+
7
  classifier = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-3")
8
 
 
9
  labels = [
10
  "animals", "people", "places", "history", "science", "art", "technology",
11
  "sports", "food", "clothing", "home", "entertainment", "education", "nature", "transportation"
12
  ]
13
 
14
+
15
  def preprocess_topic(topic):
16
  topic = topic.lower().strip()
17
  mapping = {
 
34
  }
35
  return mapping.get(topic, topic)
36
 
 
37
  random_topics = [
38
  "cats", "space", "chocolate", "Egypt", "Leonardo da Vinci",
39
  "volcanoes", "Tokyo", "honeybees", "quantum physics", "orcas"
40
  ]
41
 
42
+
43
  def get_wikipedia_facts(topic):
44
  if not topic.strip():
45
  return "Please enter a topic or use 'Surprise me!'", None, None
 
78
  facts = [fact if fact.endswith(".") else fact + "." for fact in facts]
79
  facts_text = "\n\n".join(f"💡 {fact}" for fact in facts)
80
 
81
+
82
  processed_input = preprocess_topic(topic)
83
  classification = classifier(processed_input, candidate_labels=labels)
84
  top_labels = classification["labels"][:3]
 
96
  print("Error:", e)
97
  return "Oops! Something went wrong while fetching your facts.", None, None
98
 
99
+
100
  def surprise_topic(_):
101
  topic = random.choice(random_topics)
102
  return get_wikipedia_facts(topic)
103
 
104
+
105
  with gr.Blocks() as demo:
106
  gr.HTML("""
107
  <style>