Minte committed on
Commit
07e06da
·
1 Parent(s): ad4b7e9

solve the problem

Browse files
Files changed (1) hide show
  1. app.py +64 -131
app.py CHANGED
@@ -1,153 +1,86 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
- from fastapi import FastAPI, Request
5
 
6
- # -------------------------------------------------
7
- # 1. Load model (same as your old code)
8
- # -------------------------------------------------
9
- print("Initializing DialoGPT-medium model...")
10
- model_name = "microsoft/DialoGPT-medium"
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
- model = AutoModelForCausalLM.from_pretrained(model_name)
13
-
14
- if tokenizer.pad_token is None:
15
- tokenizer.pad_token = tokenizer.eos_token
16
-
17
- print("DialoGPT-medium loaded!")
18
-
19
- # -------------------------------------------------
20
- # 2. Generation helper (your old logic, cleaned up)
21
- # -------------------------------------------------
22
- def generate_response(message: str, chat_history: list):
23
- if not message.strip():
24
- return "Please enter a message."
25
-
26
- # Build conversation string
27
- conv = ""
28
- for user, bot in chat_history:
29
- conv += f"User: {user}\nBot: {bot}\n"
30
- conv += f"User: {message}\nBot:"
31
-
32
- # Encode
33
- inputs = tokenizer.encode(conv, return_tensors="pt", max_length=1024, truncation=True)
34
-
35
- # Generate
 
 
 
 
 
 
 
36
  with torch.no_grad():
37
  outputs = model.generate(
38
  inputs,
39
- max_length=inputs.shape[1] + 128,
40
  pad_token_id=tokenizer.eos_token_id,
41
  do_sample=True,
42
  temperature=0.7,
43
  top_k=50,
44
  top_p=0.95,
45
- repetition_penalty=1.2,
46
  )
47
-
48
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
  response = response.split("Bot:")[-1].strip()
 
 
50
  if "\nUser:" in response:
51
  response = response.split("\nUser:")[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- return response
54
-
55
- # -------------------------------------------------
56
- # 3. Gradio chat function (used by /run/predict)
57
- # -------------------------------------------------
58
- def chat_fn(message: str, history: list):
59
- response = generate_response(message, history or [])
60
- history.append((message, response))
61
- return "", history # clear textbox, update chat
62
-
63
- # -------------------------------------------------
64
- # 4. Build the UI (your Blocks layout)
65
- # -------------------------------------------------
66
- example_questions = [
67
- "Hello! How are you today?",
68
- "What can you help me with?",
69
- "Tell me about artificial intelligence",
70
- "What's your favorite programming language?",
71
- "Can you explain machine learning?",
72
- "How does a neural network work?"
73
- ]
74
-
75
- with gr.Blocks(
76
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
77
- title="GihonTech - AI Conversation Assistant"
78
- ) as demo:
79
-
80
- gr.Markdown("# GihonTech AI Conversation Assistant")
81
- gr.Markdown("Chat with an AI powered by **DialoGPT-medium**")
82
-
83
- with gr.Row():
84
- with gr.Column(scale=3):
85
- chatbot = gr.Chatbot(label="Conversation", height=500)
86
-
87
- with gr.Row():
88
- msg = gr.Textbox(
89
- label="Your Message",
90
- placeholder="Type your message here...",
91
- lines=2,
92
- scale=4,
93
- )
94
- send = gr.Button("Send", variant="primary", scale=1)
95
-
96
- clear = gr.Button("Clear Chat", variant="secondary")
97
-
98
- with gr.Column(scale=1):
99
- gr.Markdown("### Example Questions")
100
- for q in example_questions:
101
- gr.Button(q[:40] + ("..." if len(q) > 40 else ""), size="sm").click(
102
- lambda x=q: x, outputs=msg
103
- )
104
- gr.Markdown("---")
105
- gr.Markdown("### Model Info")
106
- gr.Textbox(
107
- value="DialoGPT-medium: Loaded",
108
- label="Model Status",
109
- interactive=False,
110
- )
111
- gr.Markdown(
112
- """
113
- **Features**
114
- - Context-aware replies
115
- - Conversation memory
116
-
117
- **Tips**
118
- - Ask clear questions
119
- - Use *Clear Chat* to start over
120
- """
121
- )
122
-
123
- # Event wiring
124
- send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
125
- msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
126
- clear.click(lambda: ([], ""), outputs=[chatbot, msg])
127
-
128
- # -------------------------------------------------
129
- # 5. OPTIONAL: expose /lambda (same JSON format)
130
- # -------------------------------------------------
131
- fastapi_app = FastAPI()
132
-
133
- @fastapi_app.post("/lambda")
134
- async def lambda_endpoint(req: Request):
135
- payload = await req.json()
136
- # Gradio sends {"data": [...]} ; we accept anything
137
- user_msg = payload.get("data", [""])[0]
138
- # Use the same generation logic (no history for this endpoint)
139
- resp = generate_response(user_msg, [])
140
- return {"data": [resp]}
141
-
142
- demo.mount_app(fastapi_app) # makes /lambda reachable
143
-
144
- # -------------------------------------------------
145
- # 6. Launch with queue (critical for API!)
146
- # -------------------------------------------------
147
  if __name__ == "__main__":
148
- demo.queue().launch(
149
  server_name="0.0.0.0",
150
  server_port=7860,
151
- share=False,
152
- show_error=True,
153
  )
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
+ # Initialize model and tokenizer
6
+ model = None
7
+ tokenizer = None
8
+
9
+ print("🚀 Initializing DialoGPT-medium model...")
10
+
11
+ try:
12
+ print("📥 Loading DialoGPT-medium model...")
13
+ model_name = "microsoft/DialoGPT-medium"
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ model = AutoModelForCausalLM.from_pretrained(model_name)
16
+ print("DialoGPT-medium model loaded successfully!")
17
+
18
+ # Add padding token if it doesn't exist
19
+ if tokenizer.pad_token is None:
20
+ tokenizer.pad_token = tokenizer.eos_token
21
+
22
+ except Exception as e:
23
+ print(f" Failed to load DialoGPT-medium model: {e}")
24
+ model = None
25
+ tokenizer = None
26
+
27
+ def respond(message, chat_history):
28
+ """Respond to user message using DialoGPT"""
29
+ if model is None or tokenizer is None:
30
+ return "Model not loaded. Please try again later."
31
+
32
+ # Build conversation history
33
+ conversation = ""
34
+ for turn in chat_history:
35
+ conversation += f"User: {turn[0]}\nBot: {turn[1]}\n"
36
+
37
+ conversation += f"User: {message}\nBot:"
38
+
39
+ # Encode and generate
40
+ inputs = tokenizer.encode(conversation, return_tensors='pt', max_length=1024, truncation=True)
41
+
42
  with torch.no_grad():
43
  outputs = model.generate(
44
  inputs,
45
+ max_length=len(inputs[0]) + 128,
46
  pad_token_id=tokenizer.eos_token_id,
47
  do_sample=True,
48
  temperature=0.7,
49
  top_k=50,
50
  top_p=0.95,
51
+ repetition_penalty=1.2
52
  )
53
+
54
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
  response = response.split("Bot:")[-1].strip()
56
+
57
+ # Clean response
58
  if "\nUser:" in response:
59
  response = response.split("\nUser:")[0]
60
+
61
+ chat_history.append((message, response))
62
+ return "", chat_history
63
+
64
+ # Create the chat interface
65
+ demo = gr.ChatInterface(
66
+ fn=respond,
67
+ title="💬 GihonTech AI Conversation Assistant",
68
+ description="Chat with an AI powered by Microsoft's DialoGPT-medium model",
69
+ examples=[
70
+ "Hello! How are you today?",
71
+ "What can you help me with?",
72
+ "Tell me about artificial intelligence",
73
+ "What's your favorite programming language?",
74
+ ],
75
+ cache_examples=False,
76
+ retry_btn=None,
77
+ undo_btn="↩️ Undo",
78
+ clear_btn="🗑️ Clear"
79
+ )
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  if __name__ == "__main__":
82
+ demo.launch(
83
  server_name="0.0.0.0",
84
  server_port=7860,
85
+ share=False
 
86
  )