Minte committed on
Commit
c3e5dc7
·
1 Parent(s): 07e06da

lets try now

Browse files
Files changed (1) hide show
  1. app.py +131 -64
app.py CHANGED
@@ -1,86 +1,153 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
- # Initialize model and tokenizer
6
- model = None
7
- tokenizer = None
8
-
9
- print("🚀 Initializing DialoGPT-medium model...")
10
-
11
- try:
12
- print("📥 Loading DialoGPT-medium model...")
13
- model_name = "microsoft/DialoGPT-medium"
14
- tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- model = AutoModelForCausalLM.from_pretrained(model_name)
16
- print("DialoGPT-medium model loaded successfully!")
17
-
18
- # Add padding token if it doesn't exist
19
- if tokenizer.pad_token is None:
20
- tokenizer.pad_token = tokenizer.eos_token
21
-
22
- except Exception as e:
23
- print(f" Failed to load DialoGPT-medium model: {e}")
24
- model = None
25
- tokenizer = None
26
-
27
- def respond(message, chat_history):
28
- """Respond to user message using DialoGPT"""
29
- if model is None or tokenizer is None:
30
- return "Model not loaded. Please try again later."
31
-
32
- # Build conversation history
33
- conversation = ""
34
- for turn in chat_history:
35
- conversation += f"User: {turn[0]}\nBot: {turn[1]}\n"
36
-
37
- conversation += f"User: {message}\nBot:"
38
-
39
- # Encode and generate
40
- inputs = tokenizer.encode(conversation, return_tensors='pt', max_length=1024, truncation=True)
41
-
42
  with torch.no_grad():
43
  outputs = model.generate(
44
  inputs,
45
- max_length=len(inputs[0]) + 128,
46
  pad_token_id=tokenizer.eos_token_id,
47
  do_sample=True,
48
  temperature=0.7,
49
  top_k=50,
50
  top_p=0.95,
51
- repetition_penalty=1.2
52
  )
53
-
54
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
  response = response.split("Bot:")[-1].strip()
56
-
57
- # Clean response
58
  if "\nUser:" in response:
59
  response = response.split("\nUser:")[0]
60
-
61
- chat_history.append((message, response))
62
- return "", chat_history
63
-
64
- # Create the chat interface
65
- demo = gr.ChatInterface(
66
- fn=respond,
67
- title="💬 GihonTech AI Conversation Assistant",
68
- description="Chat with an AI powered by Microsoft's DialoGPT-medium model",
69
- examples=[
70
- "Hello! How are you today?",
71
- "What can you help me with?",
72
- "Tell me about artificial intelligence",
73
- "What's your favorite programming language?",
74
- ],
75
- cache_examples=False,
76
- retry_btn=None,
77
- undo_btn="↩️ Undo",
78
- clear_btn="🗑️ Clear"
79
- )
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  if __name__ == "__main__":
82
- demo.launch(
83
  server_name="0.0.0.0",
84
  server_port=7860,
85
- share=False
 
86
  )
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from fastapi import FastAPI, Request
5
 
6
+ # -------------------------------------------------
7
+ # 1. Load model (same as your old code)
8
+ # -------------------------------------------------
9
+ print("Initializing DialoGPT-medium model...")
10
+ model_name = "microsoft/DialoGPT-medium"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForCausalLM.from_pretrained(model_name)
13
+
14
+ if tokenizer.pad_token is None:
15
+ tokenizer.pad_token = tokenizer.eos_token
16
+
17
+ print("DialoGPT-medium loaded!")
18
+
19
+ # -------------------------------------------------
20
+ # 2. Generation helper (your old logic, cleaned up)
21
+ # -------------------------------------------------
22
+ def generate_response(message: str, chat_history: list):
23
+ if not message.strip():
24
+ return "Please enter a message."
25
+
26
+ # Build conversation string
27
+ conv = ""
28
+ for user, bot in chat_history:
29
+ conv += f"User: {user}\nBot: {bot}\n"
30
+ conv += f"User: {message}\nBot:"
31
+
32
+ # Encode
33
+ inputs = tokenizer.encode(conv, return_tensors="pt", max_length=1024, truncation=True)
34
+
35
+ # Generate
 
 
 
 
 
 
 
36
  with torch.no_grad():
37
  outputs = model.generate(
38
  inputs,
39
+ max_length=inputs.shape[1] + 128,
40
  pad_token_id=tokenizer.eos_token_id,
41
  do_sample=True,
42
  temperature=0.7,
43
  top_k=50,
44
  top_p=0.95,
45
+ repetition_penalty=1.2,
46
  )
47
+
48
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
  response = response.split("Bot:")[-1].strip()
 
 
50
  if "\nUser:" in response:
51
  response = response.split("\nUser:")[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ return response
54
+
55
+ # -------------------------------------------------
56
+ # 3. Gradio chat function (used by /run/predict)
57
+ # -------------------------------------------------
58
+ def chat_fn(message: str, history: list):
59
+ response = generate_response(message, history or [])
60
+ history.append((message, response))
61
+ return "", history # clear textbox, update chat
62
+
# -------------------------------------------------
# 4. Build the UI (your Blocks layout)
# -------------------------------------------------
# Canned prompts rendered as one-click buttons in the sidebar.
example_questions = [
    "Hello! How are you today?",
    "What can you help me with?",
    "Tell me about artificial intelligence",
    "What's your favorite programming language?",
    "Can you explain machine learning?",
    "How does a neural network work?"
]

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green"),
    title="GihonTech - AI Conversation Assistant"
) as demo:

    gr.Markdown("# GihonTech AI Conversation Assistant")
    gr.Markdown("Chat with an AI powered by **DialoGPT-medium**")

    with gr.Row():
        # Left column: the conversation itself (chat log + input row).
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation", height=500)

            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4,
                )
                send = gr.Button("Send", variant="primary", scale=1)

            clear = gr.Button("Clear Chat", variant="secondary")

        # Right column: example prompts plus static model info.
        with gr.Column(scale=1):
            gr.Markdown("### Example Questions")
            for q in example_questions:
                # `x=q` binds the current question as a default argument,
                # sidestepping the classic late-binding-closure bug; the
                # click just copies the question into the textbox.
                gr.Button(q[:40] + ("..." if len(q) > 40 else ""), size="sm").click(
                    lambda x=q: x, outputs=msg
                )
            gr.Markdown("---")
            gr.Markdown("### Model Info")
            # Static status display only; not updated at runtime.
            gr.Textbox(
                value="DialoGPT-medium: Loaded",
                label="Model Status",
                interactive=False,
            )
            gr.Markdown(
                """
                **Features**
                - Context-aware replies
                - Conversation memory

                **Tips**
                - Ask clear questions
                - Use *Clear Chat* to start over
                """
            )

    # Event wiring: button click and Enter-submit share one handler;
    # clear resets both the chat log and the textbox.
    send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
# -------------------------------------------------
# 5. OPTIONAL: expose /lambda (same JSON format)
# -------------------------------------------------
fastapi_app = FastAPI()

@fastapi_app.post("/lambda")
async def lambda_endpoint(req: Request):
    """Stateless JSON endpoint mirroring Gradio's ``{"data": [...]}`` format."""
    payload = await req.json()
    # Gradio sends {"data": [...]} ; tolerate a missing or EMPTY list --
    # the original [0] index raised IndexError on {"data": []}.
    data = payload.get("data", [""])
    user_msg = data[0] if data else ""
    # Use the same generation logic (no history for this endpoint)
    resp = generate_response(user_msg, [])
    return {"data": [resp]}

# gr.Blocks has no `mount_app` method -- the original `demo.mount_app(...)`
# raised AttributeError at import time.  The supported API mounts the Gradio
# app INTO the FastAPI app so /lambda and the UI share one ASGI application.
# Serve it with `uvicorn app:app` to expose both (demo.launch() alone only
# serves the Gradio UI).
app = gr.mount_gradio_app(fastapi_app, demo, path="/")
# -------------------------------------------------
# 6. Launch with queue (critical for API!)
# -------------------------------------------------
if __name__ == "__main__":
    # queue() enables request queuing so concurrent UI/API calls are
    # serialized instead of hammering the single in-process model.
    # NOTE(review): demo.launch() serves only the Gradio app; the FastAPI
    # routes defined above need uvicorn run against the FastAPI app --
    # confirm which entry point the deployment actually uses.
    demo.queue().launch(
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,       # Hugging Face Spaces' default port
        share=False,
        show_error=True,        # surface tracebacks in the UI for debugging
    )