mvp
Browse files- .gitattributes +2 -0
- 1.png +0 -0
- README.md +15 -15
- __init__.py +0 -0
- app.py +1330 -722
- app_pro.py +840 -0
- audio_127.0.0.1.wav +3 -0
- image_127.0.0.1.jpg +0 -0
- requirements.txt +8 -4
- se_app.py +232 -0
- temp_audio.wav +3 -0
- todogen_LLM_config.yaml +11 -1
- tools.py +828 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
audio_127.0.0.1.wav filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
temp_audio.wav filter=lfs diff=lfs merge=lfs -text
|
1.png
ADDED
|
README.md
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: ToDoAgent
|
| 3 |
-
emoji: 💬
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.32.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: bsd
|
| 11 |
-
short_description: AI Agent filters, creates to-do list and reminds smartly
|
| 12 |
-
tags: ['agent-demo-track']
|
| 13 |
-
demo: https://youtu.be/S-wh3Psx15M?si=Wiq7EzmE3dmBvLKQ
|
| 14 |
-
---
|
| 15 |
-
|
| 16 |
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ToDoAgent
|
| 3 |
+
emoji: 💬
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.32.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: bsd
|
| 11 |
+
short_description: AI Agent filters, creates to-do list and reminds smartly
|
| 12 |
+
tags: ['agent-demo-track']
|
| 13 |
+
demo: https://youtu.be/S-wh3Psx15M?si=Wiq7EzmE3dmBvLKQ
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
__init__.py
ADDED
|
File without changes
|
app.py
CHANGED
|
@@ -1,722 +1,1330 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import json
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import yaml
|
| 5 |
-
import re
|
| 6 |
-
import logging
|
| 7 |
-
import io
|
| 8 |
-
import sys
|
| 9 |
-
import
|
| 10 |
-
|
| 11 |
-
import
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
try:
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
return
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
#
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
}
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
return
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
if not
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import yaml
|
| 5 |
+
import re
|
| 6 |
+
import logging
|
| 7 |
+
import io
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
from datetime import datetime, timezone, timedelta
|
| 12 |
+
import requests
|
| 13 |
+
|
| 14 |
+
from tools import FileUploader, ResultExtractor, audio_to_str, image_to_str, azure_speech_to_text #gege的多模态
|
| 15 |
+
import numpy as np
|
| 16 |
+
from scipy.io.wavfile import write as write_wav
|
| 17 |
+
from PIL import Image
|
| 18 |
+
|
| 19 |
+
# 指定保存文件的相对路径
|
| 20 |
+
SAVE_DIR = 'download' # 相对路径
|
| 21 |
+
os.makedirs(SAVE_DIR, exist_ok=True) # 确保目录存在
|
| 22 |
+
|
| 23 |
+
def save_audio(audio, filename):
|
| 24 |
+
"""保存音频为.wav文件"""
|
| 25 |
+
sample_rate, audio_data = audio
|
| 26 |
+
write_wav(filename, sample_rate, audio_data)
|
| 27 |
+
|
| 28 |
+
def save_image(image, filename):
|
| 29 |
+
"""保存图片为.jpg文件"""
|
| 30 |
+
img = Image.fromarray(image.astype('uint8'))
|
| 31 |
+
img.save(filename)
|
| 32 |
+
|
| 33 |
+
# --- IP获取功能 (从 se_app.py 迁移) ---
|
| 34 |
+
def get_client_ip(request: gr.Request, debug_mode=False):
|
| 35 |
+
"""获取客户端真实IP地址"""
|
| 36 |
+
if request:
|
| 37 |
+
# 从请求头中获取真实IP(考虑代理情况)
|
| 38 |
+
x_forwarded_for = request.headers.get("x-forwarded-for", "")
|
| 39 |
+
if x_forwarded_for:
|
| 40 |
+
client_ip = x_forwarded_for.split(",")[0]
|
| 41 |
+
else:
|
| 42 |
+
client_ip = request.client.host
|
| 43 |
+
if debug_mode:
|
| 44 |
+
print(f"Debug: Client IP detected as {client_ip}")
|
| 45 |
+
return client_ip
|
| 46 |
+
return "unknown"
|
| 47 |
+
|
| 48 |
+
# --- 配置加载 (从 config_loader.py 迁移并简化) ---
|
| 49 |
+
CONFIG = None
|
| 50 |
+
HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"
|
| 51 |
+
|
| 52 |
+
def load_hf_config():
|
| 53 |
+
global CONFIG
|
| 54 |
+
if CONFIG is None:
|
| 55 |
+
try:
|
| 56 |
+
with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
|
| 57 |
+
CONFIG = yaml.safe_load(f)
|
| 58 |
+
print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
|
| 59 |
+
except FileNotFoundError:
|
| 60 |
+
print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
|
| 61 |
+
CONFIG = {} # 提供一个空配置以避免后续错误
|
| 62 |
+
except Exception as e:
|
| 63 |
+
print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
|
| 64 |
+
CONFIG = {}
|
| 65 |
+
return CONFIG
|
| 66 |
+
|
| 67 |
+
def get_hf_openai_config():
|
| 68 |
+
config = load_hf_config()
|
| 69 |
+
return config.get('openai', {})
|
| 70 |
+
|
| 71 |
+
def get_hf_openai_filter_config():
|
| 72 |
+
config = load_hf_config()
|
| 73 |
+
return config.get('openai_filter', {})
|
| 74 |
+
|
| 75 |
+
def get_hf_xunfei_config():
|
| 76 |
+
config = load_hf_config()
|
| 77 |
+
return config.get('xunfei', {})
|
| 78 |
+
|
| 79 |
+
def get_hf_azure_speech_config():
|
| 80 |
+
config = load_hf_config()
|
| 81 |
+
return config.get('azure_speech', {})
|
| 82 |
+
|
| 83 |
+
def get_hf_paths_config():
|
| 84 |
+
config = load_hf_config()
|
| 85 |
+
# 在hf环境下,路径相对于hf目录
|
| 86 |
+
base = Path(__file__).resolve().parent
|
| 87 |
+
paths_cfg = config.get('paths', {})
|
| 88 |
+
return {
|
| 89 |
+
'base_dir': base,
|
| 90 |
+
'prompt_template': base / paths_cfg.get('prompt_template', 'prompt_template.txt'),
|
| 91 |
+
'true_positive_examples': base / paths_cfg.get('true_positive_examples', 'TruePositive_few_shot.txt'),
|
| 92 |
+
'false_positive_examples': base / paths_cfg.get('false_positive_examples', 'FalsePositive_few_shot.txt'),
|
| 93 |
+
# data_dir 和 logging_dir 在 app.py 中可能用途不大,除非需要保存 LLM 输出
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
# --- LLM Client 初始化 (使用 NVIDIA API) ---
|
| 97 |
+
# 从配置加载 NVIDIA API 的 base_url, api_key 和 model
|
| 98 |
+
llm_config = get_hf_openai_config()
|
| 99 |
+
NVIDIA_API_BASE_URL = llm_config.get('base_url')
|
| 100 |
+
NVIDIA_API_KEY = llm_config.get('api_key')
|
| 101 |
+
NVIDIA_MODEL_NAME = llm_config.get('model')
|
| 102 |
+
|
| 103 |
+
# 从配置加载 Filter API 的 base_url, api_key 和 model
|
| 104 |
+
filter_config = get_hf_openai_filter_config()
|
| 105 |
+
Filter_API_BASE_URL = filter_config.get('base_url_filter')
|
| 106 |
+
Filter_API_KEY = filter_config.get('api_key_filter')
|
| 107 |
+
Filter_MODEL_NAME = filter_config.get('model_filter')
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
|
| 111 |
+
print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
|
| 112 |
+
# 提供默认值或退出,以便程序可以继续运行,但LLM调用会失败
|
| 113 |
+
NVIDIA_API_BASE_URL = ""
|
| 114 |
+
NVIDIA_API_KEY = ""
|
| 115 |
+
NVIDIA_MODEL_NAME = ""
|
| 116 |
+
|
| 117 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
| 118 |
+
print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
|
| 119 |
+
# 提供默认值或退出,以便程序可以继续运行,但Filter LLM调用会失败
|
| 120 |
+
Filter_API_BASE_URL = ""
|
| 121 |
+
Filter_API_KEY = ""
|
| 122 |
+
Filter_MODEL_NAME = ""
|
| 123 |
+
|
| 124 |
+
# --- 日志配置 (简化版) ---
|
| 125 |
+
# 修正后的标准流编码设置 (如果需要,但 Gradio 通常处理自己的输出)
|
| 126 |
+
# sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
|
| 127 |
+
# sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True)
|
| 128 |
+
# sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', write_through=True)
|
| 129 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 130 |
+
logger = logging.getLogger(__name__)
|
| 131 |
+
|
| 132 |
+
# --- Prompt 和 Few-Shot 加载 (从 todogen_llm.py 迁移并适配) ---
|
| 133 |
+
def load_single_few_shot_file_hf(file_path: Path) -> str:
|
| 134 |
+
"""加载单个 few-shot 文件并转义 { 和 }"""
|
| 135 |
+
try:
|
| 136 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 137 |
+
content = f.read()
|
| 138 |
+
escaped_content = content.replace('{', '{{').replace('}', '}}')
|
| 139 |
+
logger.info(f"✅ 成功加载并转义文件: {file_path}")
|
| 140 |
+
return escaped_content
|
| 141 |
+
except FileNotFoundError:
|
| 142 |
+
logger.warning(f"⚠️ 警告:找不到文件 {file_path}。")
|
| 143 |
+
return ""
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error(f"❌ 加载文件 {file_path} 时出错: {e}", exc_info=True)
|
| 146 |
+
return ""
|
| 147 |
+
|
| 148 |
+
PROMPT_TEMPLATE_CONTENT = ""
|
| 149 |
+
TRUE_POSITIVE_EXAMPLES_CONTENT = ""
|
| 150 |
+
FALSE_POSITIVE_EXAMPLES_CONTENT = ""
|
| 151 |
+
|
| 152 |
+
def load_prompt_data_hf():
|
| 153 |
+
global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
|
| 154 |
+
paths = get_hf_paths_config()
|
| 155 |
+
try:
|
| 156 |
+
with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
|
| 157 |
+
PROMPT_TEMPLATE_CONTENT = f.read()
|
| 158 |
+
logger.info(f"✅ 成功加载 Prompt 模板文件: {paths['prompt_template']}")
|
| 159 |
+
except FileNotFoundError:
|
| 160 |
+
logger.error(f"❌ 错误:找不到 Prompt 模板文件:{paths['prompt_template']}")
|
| 161 |
+
PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."
|
| 162 |
+
|
| 163 |
+
TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
|
| 164 |
+
FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])
|
| 165 |
+
|
| 166 |
+
# 应用启动时加载 prompts
|
| 167 |
+
load_prompt_data_hf()
|
| 168 |
+
|
| 169 |
+
# --- JSON 解析器 (从 todogen_llm.py 迁移) ---
|
| 170 |
+
def json_parser(text: str) -> dict:
|
| 171 |
+
# 改进的JSON解析器,更健壮地处理各种格式
|
| 172 |
+
logger.info(f"Attempting to parse: {text[:200]}...")
|
| 173 |
+
try:
|
| 174 |
+
# 1. 尝试直接将整个文本作为JSON解析
|
| 175 |
+
try:
|
| 176 |
+
parsed_data = json.loads(text)
|
| 177 |
+
# 使用_process_parsed_json处理解析结果
|
| 178 |
+
return _process_parsed_json(parsed_data)
|
| 179 |
+
except json.JSONDecodeError:
|
| 180 |
+
pass # 如果直接解析失败,继续尝试提取代码块
|
| 181 |
+
|
| 182 |
+
# 2. 尝试从 ```json ... ``` 代码块中提取和解析
|
| 183 |
+
match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
|
| 184 |
+
if match:
|
| 185 |
+
json_str = match.group(1).strip()
|
| 186 |
+
# 修复常见的JSON格式问题
|
| 187 |
+
json_str = re.sub(r',\s*]', ']', json_str)
|
| 188 |
+
json_str = re.sub(r',\s*}', '}', json_str)
|
| 189 |
+
try:
|
| 190 |
+
parsed_data = json.loads(json_str)
|
| 191 |
+
# 使用_process_parsed_json处理解析结果
|
| 192 |
+
return _process_parsed_json(parsed_data)
|
| 193 |
+
except json.JSONDecodeError as e_block:
|
| 194 |
+
logger.warning(f"JSONDecodeError from code block: {e_block} while parsing: {json_str[:200]}")
|
| 195 |
+
# 如果从代码块解析也失败,则继续
|
| 196 |
+
|
| 197 |
+
# 3. 尝试查找最外层的 '{...}' 或 '[...]' 作为JSON
|
| 198 |
+
# 先尝试查找数组格式 [...]
|
| 199 |
+
array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
|
| 200 |
+
if array_match:
|
| 201 |
+
potential_json = array_match.group(0).strip()
|
| 202 |
+
try:
|
| 203 |
+
parsed_data = json.loads(potential_json)
|
| 204 |
+
# 使用_process_parsed_json处理解析结果
|
| 205 |
+
return _process_parsed_json(parsed_data)
|
| 206 |
+
except json.JSONDecodeError:
|
| 207 |
+
logger.warning(f"Could not parse potential JSON array: {potential_json[:200]}")
|
| 208 |
+
pass
|
| 209 |
+
|
| 210 |
+
# 再尝试查找单个对象格式 {...}
|
| 211 |
+
object_match = re.search(r'\{.*?\}', text, re.DOTALL)
|
| 212 |
+
if object_match:
|
| 213 |
+
potential_json = object_match.group(0).strip()
|
| 214 |
+
try:
|
| 215 |
+
parsed_data = json.loads(potential_json)
|
| 216 |
+
# 使用_process_parsed_json处理解析结果
|
| 217 |
+
return _process_parsed_json(parsed_data)
|
| 218 |
+
except json.JSONDecodeError:
|
| 219 |
+
logger.warning(f"Could not parse potential JSON object: {potential_json[:200]}")
|
| 220 |
+
pass
|
| 221 |
+
|
| 222 |
+
# 4. 如果所有尝试都失败,返回错误信息
|
| 223 |
+
logger.error(f"Failed to find or parse JSON block in text: {text[:500]}") # 增加日志长度
|
| 224 |
+
return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
|
| 225 |
+
|
| 226 |
+
except Exception as e: # 捕获所有其他意外错误
|
| 227 |
+
logger.error(f"Unexpected error in json_parser: {e} for text: {text[:200]}", exc_info=True)
|
| 228 |
+
return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
|
| 229 |
+
|
| 230 |
+
def _process_parsed_json(parsed_data):
|
| 231 |
+
"""处理解析后的JSON数据,确保返回有效的数据结构"""
|
| 232 |
+
try:
|
| 233 |
+
# 如果解析结果是空列表,���回包含空字典的列表
|
| 234 |
+
if isinstance(parsed_data, list):
|
| 235 |
+
if not parsed_data:
|
| 236 |
+
logger.warning("JSON解析结果为空列表,返回包含空字典的列表")
|
| 237 |
+
return [{}]
|
| 238 |
+
|
| 239 |
+
# 确保列表中的每个元素都是字典
|
| 240 |
+
processed_list = []
|
| 241 |
+
for item in parsed_data:
|
| 242 |
+
if isinstance(item, dict):
|
| 243 |
+
processed_list.append(item)
|
| 244 |
+
else:
|
| 245 |
+
# 如果不是字典,将其转换为字典
|
| 246 |
+
try:
|
| 247 |
+
processed_list.append({"content": str(item)})
|
| 248 |
+
except:
|
| 249 |
+
processed_list.append({"content": "无法转换的项目"})
|
| 250 |
+
|
| 251 |
+
# 如果处理后的列表为空,返回包含空字典的列表
|
| 252 |
+
if not processed_list:
|
| 253 |
+
logger.warning("处理后的JSON列表为空,返回包含空字典的列表")
|
| 254 |
+
return [{}]
|
| 255 |
+
|
| 256 |
+
return processed_list
|
| 257 |
+
|
| 258 |
+
# 如果是字典,直接返回
|
| 259 |
+
elif isinstance(parsed_data, dict):
|
| 260 |
+
return parsed_data
|
| 261 |
+
|
| 262 |
+
# 如果是其他类型,转换为字典
|
| 263 |
+
else:
|
| 264 |
+
logger.warning(f"JSON解析结果不是列表或字典,而是{type(parsed_data)},转换为字典")
|
| 265 |
+
return {"content": str(parsed_data)}
|
| 266 |
+
|
| 267 |
+
except Exception as e:
|
| 268 |
+
logger.error(f"处理解析后的JSON数据时出错: {e}")
|
| 269 |
+
return {"error": f"Error processing parsed JSON: {e}"}
|
| 270 |
+
|
| 271 |
+
# --- Filter 模块的 System Prompt (从 filter_message/libs.py 迁移) ---
|
| 272 |
+
FILTER_SYSTEM_PROMPT = """
|
| 273 |
+
# 角色
|
| 274 |
+
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
|
| 275 |
+
|
| 276 |
+
# 任务
|
| 277 |
+
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
|
| 278 |
+
主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
|
| 279 |
+
|
| 280 |
+
# 要求
|
| 281 |
+
1. 以json格式输出
|
| 282 |
+
2. content简洁提炼关键词,字符数<20以内
|
| 283 |
+
3. 输入条数和输出条数完全一样
|
| 284 |
+
|
| 285 |
+
# 输出示例
|
| 286 |
+
```
|
| 287 |
+
[
|
| 288 |
+
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
|
| 289 |
+
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议"}
|
| 290 |
+
]
|
| 291 |
+
|
| 292 |
+
```
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
# --- Filter 核心逻辑 (从ToDoAgent集成) ---
|
| 296 |
+
def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
|
| 297 |
+
logger.info(f"调用 filter_message_with_llm 处理输入: {text_input} (msg_id: {message_id})")
|
| 298 |
+
|
| 299 |
+
# 构造发送给 LLM 的消息
|
| 300 |
+
# filter 模块的 send_llm_with_prompt 接收的是 tuple[tuple] 格式的数据
|
| 301 |
+
# 这里我们只有一个文本输入,需要模拟成那种格式
|
| 302 |
+
mock_data = [(text_input, message_id)]
|
| 303 |
+
|
| 304 |
+
# 使用与ToDoAgent相同的system prompt
|
| 305 |
+
system_prompt = """
|
| 306 |
+
# 角色
|
| 307 |
+
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
|
| 308 |
+
|
| 309 |
+
# 任务
|
| 310 |
+
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
|
| 311 |
+
主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
|
| 312 |
+
|
| 313 |
+
# 要求
|
| 314 |
+
1. 以json格式输出
|
| 315 |
+
2. content简洁提炼关键词,字符数<20以内
|
| 316 |
+
3. 输入条数和输出条数完全一样
|
| 317 |
+
|
| 318 |
+
# 输出示例
|
| 319 |
+
```
|
| 320 |
+
[
|
| 321 |
+
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
|
| 322 |
+
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
|
| 323 |
+
]
|
| 324 |
+
```
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
llm_messages = [
|
| 328 |
+
{"role": "system", "content": system_prompt},
|
| 329 |
+
{"role": "user", "content": str(mock_data)}
|
| 330 |
+
]
|
| 331 |
+
|
| 332 |
+
try:
|
| 333 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
| 334 |
+
logger.error("Filter API 配置不完整,无法调用 Filter LLM。")
|
| 335 |
+
return [{"error": "Filter API configuration incomplete", "-": "-"}]
|
| 336 |
+
|
| 337 |
+
headers = {
|
| 338 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
| 339 |
+
"Accept": "application/json"
|
| 340 |
+
}
|
| 341 |
+
payload = {
|
| 342 |
+
"model": Filter_MODEL_NAME,
|
| 343 |
+
"messages": llm_messages,
|
| 344 |
+
"temperature": 0.0, # 为提高准确率,温度为0(与ToDoAgent一致)
|
| 345 |
+
"top_p": 0.95,
|
| 346 |
+
"max_tokens": 1024,
|
| 347 |
+
"stream": False
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
response = requests.post(api_url, headers=headers, json=payload)
|
| 354 |
+
response.raise_for_status() # 检查 HTTP 错误
|
| 355 |
+
raw_llm_response = response.json()["choices"][0]["message"]["content"]
|
| 356 |
+
logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...")
|
| 357 |
+
except requests.exceptions.RequestException as e:
|
| 358 |
+
logger.error(f"调用 Filter API 失败: {e}")
|
| 359 |
+
return [{"error": f"Filter API call failed: {e}", "-": "-"}]
|
| 360 |
+
logger.info(f"Filter LLM 原始回复 (部分): {raw_llm_response[:200]}...")
|
| 361 |
+
|
| 362 |
+
# 解析 LLM 响应
|
| 363 |
+
# 移除可能的代码块标记
|
| 364 |
+
raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
|
| 365 |
+
parsed_filter_data = json_parser(raw_llm_response)
|
| 366 |
+
|
| 367 |
+
if "error" in parsed_filter_data:
|
| 368 |
+
logger.error(f"解析 Filter LLM 响应失败: {parsed_filter_data['error']}")
|
| 369 |
+
return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]
|
| 370 |
+
|
| 371 |
+
# 返回解析后的数据
|
| 372 |
+
if isinstance(parsed_filter_data, list) and parsed_filter_data:
|
| 373 |
+
# 应用规则:如果分类是欠费缴纳且内容包含"缴费支出",归类为"其他"
|
| 374 |
+
for item in parsed_filter_data:
|
| 375 |
+
if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
|
| 376 |
+
item["分类"] = "其他"
|
| 377 |
+
|
| 378 |
+
# 检查是否有遗漏的消息ID(ToDoAgent的补充逻辑)
|
| 379 |
+
request_id_list = {message_id}
|
| 380 |
+
response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
|
| 381 |
+
diff = request_id_list - response_id_list
|
| 382 |
+
|
| 383 |
+
if diff:
|
| 384 |
+
logger.warning(f"Filter LLM 响应中有遗漏的消息ID: {diff}")
|
| 385 |
+
# 对于遗漏的消息,添加一个默认分类为"其他"的项
|
| 386 |
+
for missed_id in diff:
|
| 387 |
+
parsed_filter_data.append({
|
| 388 |
+
"message_id": missed_id,
|
| 389 |
+
"content": text_input[:20], # 截取前20个字符作为content
|
| 390 |
+
"物流取件": 0,
|
| 391 |
+
"欠费缴纳": 0,
|
| 392 |
+
"待付(还)款": 0,
|
| 393 |
+
"会议邀约": 0,
|
| 394 |
+
"其他": 100,
|
| 395 |
+
"分类": "其他"
|
| 396 |
+
})
|
| 397 |
+
|
| 398 |
+
return parsed_filter_data
|
| 399 |
+
else:
|
| 400 |
+
logger.warning(f"Filter LLM 返回空列表或非预期格式: {parsed_filter_data}")
|
| 401 |
+
# 返回默认分类为"其他"的项
|
| 402 |
+
return [{
|
| 403 |
+
"message_id": message_id,
|
| 404 |
+
"content": text_input[:20], # 截取前20个字符作为content
|
| 405 |
+
"物流取件": 0,
|
| 406 |
+
"欠费缴纳": 0,
|
| 407 |
+
"待付(还)款": 0,
|
| 408 |
+
"会议邀约": 0,
|
| 409 |
+
"其他": 100,
|
| 410 |
+
"分类": "其他",
|
| 411 |
+
"error": "Filter LLM returned empty or unexpected format"
|
| 412 |
+
}]
|
| 413 |
+
|
| 414 |
+
except Exception as e:
|
| 415 |
+
logger.exception(f"调用 Filter LLM 或解析时发生错误 (filter_message_with_llm)")
|
| 416 |
+
return [{
|
| 417 |
+
"message_id": message_id,
|
| 418 |
+
"content": text_input[:20], # 截取前20个字符作为content
|
| 419 |
+
"物流取件": 0,
|
| 420 |
+
"欠费缴纳": 0,
|
| 421 |
+
"待付(还)款": 0,
|
| 422 |
+
"会议邀约": 0,
|
| 423 |
+
"其他": 100,
|
| 424 |
+
"分类": "其他",
|
| 425 |
+
"error": f"Filter LLM call/parse error: {str(e)}"
|
| 426 |
+
}]
|
| 427 |
+
|
| 428 |
+
# --- ToDo List 生成核心逻辑 (使用迁移的代码) ---
|
| 429 |
+
def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
    """Generate a to-do list from free-form input text.

    Three heuristic fast paths (phone-bill top-up, meeting invitation,
    parcel pickup) produce a to-do row directly without an LLM call; any
    other message is sent to the configured Filter LLM and its JSON reply
    is converted into table rows for the Gradio DataFrame.

    Args:
        text_input: Raw message text to analyse.
        message_id: Key expected in the LLM's JSON output for this message.

    Returns:
        A list of ``[id, task_text, importance]`` rows; error conditions
        are reported as rows as well (first cell ``"error"``/``"info"``).
    """
    logger.info(f"调用 generate_todolist_from_text 处理输入: {text_input} (msg_id: {message_id})")

    # Bail out early if the prompt template failed to load at startup.
    if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
        logger.error("Prompt 模板未正确加载,无法生成 ToDoList。")
        return [["error", "Prompt template not loaded", "-"]]

    current_time_iso = datetime.now(timezone.utc).isoformat()
    # Escape braces so user text cannot break the str.format call below.
    content_escaped = text_input.replace('{', '{{').replace('}', '}}')

    # Build the prompt from the template plus the few-shot examples.
    formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
        true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
        false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
        current_time=current_time_iso,
        message_id=message_id,
        content_escaped=content_escaped
    )

    # Append an explicit JSON-only instruction to stabilise the LLM output.
    enhanced_prompt = formatted_prompt + """

# 重要提示
请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
"""

    # Message payload sent to the LLM.
    llm_messages = [
        {"role": "user", "content": enhanced_prompt}
    ]

    logger.info(f"发送给 LLM 的消息 (部分): {str(llm_messages)[:300]}...")

    try:
        # --- Heuristic fast path 1: mobile phone-bill top-up reminders ---
        if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
            # Produce the to-do item directly, skipping the API call.
            todo_item = {
                message_id: {
                    "is_todo": True,
                    "end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
                    "location": "线上:中国移动APP",
                    "todo_content": "缴纳话费",
                    "urgency": "important"
                }
            }

            # Flatten into a single display row.
            todo_content = "缴纳话费"
            end_time = todo_item[message_id]["end_time"].split("T")[0]
            location = todo_item[message_id]["location"]
            combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"

            return [[1, combined_content, "重要"]]

        # --- Heuristic fast path 2: meeting invitations ---
        elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
            # Try to detect an explicit meeting time in the text.
            meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
            meeting_match = re.search(meeting_pattern, text_input)

            if meeting_match:
                # Coarse approximation; a real implementation would parse the
                # matched date/time precisely.
                meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
            else:
                meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()

            todo_item = {
                message_id: {
                    "is_todo": True,
                    "end_time": meeting_time,
                    "location": "线上:会议软件",
                    "todo_content": "参加会议",
                    "urgency": "important"
                }
            }

            # Flatten into a single display row.
            todo_content = "参加会议"
            end_time = todo_item[message_id]["end_time"].split("T")[0]
            location = todo_item[message_id]["location"]
            combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"

            return [[1, combined_content, "重要"]]

        # --- Heuristic fast path 3: parcel / package pickup ---
        elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
            # Try to extract a 4-6 digit pickup code.
            code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
            code_match = re.search(code_pattern, text_input)

            todo_content = "取快递"
            if code_match:
                pickup_code = code_match.group(1)
                todo_content = f"取快递(取件码:{pickup_code})"

            todo_item = {
                message_id: {
                    "is_todo": True,
                    "end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
                    "location": "线下:快递柜",
                    "todo_content": todo_content,
                    "urgency": "important"
                }
            }

            # Flatten into a single display row.
            end_time = todo_item[message_id]["end_time"].split("T")[0]
            location = todo_item[message_id]["location"]
            # BUGFIX: original literal contained mojibake ("地���");
            # restored to "地点" to match the sibling branches.
            combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"

            return [[1, combined_content, "重要"]]

        # --- LLM path for every other message type ---
        if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
            logger.error("Filter API 配置不完整,无法调用 Filter LLM。")
            return [["error", "Filter API configuration incomplete", "-"]]

        headers = {
            "Authorization": f"Bearer {Filter_API_KEY}",
            "Accept": "application/json"
        }
        payload = {
            "model": Filter_MODEL_NAME,
            "messages": llm_messages,
            "temperature": 0.2,  # low temperature for more consistent output
            "top_p": 0.95,
            "max_tokens": 1024,
            "stream": False
        }

        api_url = f"{Filter_API_BASE_URL}/chat/completions"

        try:
            # Timeout added so a hung API call cannot block the app forever.
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()  # raise on HTTP error status
            raw_llm_response = response.json()['choices'][0]['message']['content']
            logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...")
        except requests.exceptions.RequestException as e:
            logger.error(f"调用 Filter API 失败: {e}")
            return [["error", f"Filter API call failed: {e}", "-"]]

        # Parse the LLM reply into Python data.
        parsed_todos_data = json_parser(raw_llm_response)

        # Guard with isinstance: a list result containing the string "error"
        # would otherwise crash on the ['error'] subscript below.
        if isinstance(parsed_todos_data, dict) and "error" in parsed_todos_data:
            logger.error(f"解析 LLM 响应失败: {parsed_todos_data['error']}")
            return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]

        output_for_df = []

        # Dict shape: {message_id: {...}} as specified by the prompt template.
        if isinstance(parsed_todos_data, dict):
            todo_info = None
            for key, value in parsed_todos_data.items():
                if key == message_id or key == str(message_id):
                    todo_info = value
                    break

            if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
                todo_content = todo_info.get("todo_content", "未指定待办内容")
                end_time = todo_info.get("end_time")
                location = todo_info.get("location")
                urgency = todo_info.get("urgency", "unimportant")

                combined_content = todo_content

                # Append the deadline (date part only when ISO-formatted).
                if end_time and end_time != "null":
                    try:
                        date_part = end_time.split("T")[0] if "T" in end_time else end_time
                        combined_content += f" (截止时间: {date_part}"
                    except Exception:  # was a bare except; narrowed
                        combined_content += f" (截止时间: {end_time}"
                else:
                    combined_content += " ("

                # Append the location.
                if location and location != "null":
                    combined_content += f", 地点: {location})"
                else:
                    combined_content += ")"

                # Map urgency codes to display labels.
                urgency_display = "一般"
                if urgency == "urgent":
                    urgency_display = "紧急"
                elif urgency == "important":
                    urgency_display = "重要"

                output_for_df = [[1, combined_content, urgency_display]]
            else:
                # Not a to-do item.
                output_for_df = [[1, "此消息不包含待办事项", "-"]]

        # Legacy shape: a flat list of to-do dicts.
        elif isinstance(parsed_todos_data, list):
            output_for_df = []

            if not parsed_todos_data:
                logger.warning("LLM 返回了空列表,无法生成 ToDo 项目")
                return [[1, "未能生成待办事项", "-"]]

            for i, item in enumerate(parsed_todos_data):
                if isinstance(item, dict):
                    todo_content = item.get('todo_content', item.get('content', 'N/A'))
                    urgency = item.get('urgency', 'normal')

                    combined_content = todo_content

                    # Append the deadline.
                    if 'end_time' in item and item['end_time']:
                        try:
                            if isinstance(item['end_time'], str):
                                date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
                                combined_content += f" (截止时间: {date_part}"
                            else:
                                combined_content += f" (截止时间: {str(item['end_time'])}"
                        except Exception as e:
                            logger.warning(f"处理end_time时出错: {e}")
                            combined_content += " ("
                    else:
                        combined_content += " ("

                    # Append the location.
                    if 'location' in item and item['location']:
                        combined_content += f", 地点: {item['location']})"
                    else:
                        combined_content += ")"

                    # Map urgency codes to display labels.
                    importance = "一般"
                    if urgency == "urgent":
                        importance = "紧急"
                    elif urgency == "important":
                        importance = "重要"

                    output_for_df.append([i + 1, combined_content, importance])
                else:
                    # Non-dict entry: render its string form.
                    try:
                        item_str = str(item) if item is not None else "未知项目"
                        output_for_df.append([i + 1, item_str, "一般"])
                    except Exception as e:
                        logger.warning(f"处理非字典项目时出错: {e}")
                        output_for_df.append([i + 1, "处理错误的项目", "一般"])

        if not output_for_df:
            logger.info("LLM 解析结果为空或无法转换为DataFrame格式。")
            return [["info", "未发现待办事项", "-"]]

        return output_for_df

    except Exception as e:
        logger.exception(f"调用 LLM 或解析时发生错误 (generate_todolist_from_text)")
        return [["error", f"LLM call/parse error: {str(e)}", "-"]]
|
| 716 |
+
|
| 717 |
+
#gradio
|
| 718 |
+
def process(audio, image, request: gr.Request):
    """处理语音和图片的示例函数"""
    # Log the requesting client's IP address.
    client_ip = get_client_ip(request, True)
    print(f"Processing audio/image request from IP: {client_ip}")

    # Summarise the audio input, if any.
    audio_info = "未收到音频"
    if audio is not None:
        sample_rate, audio_data = audio
        audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"

    # Summarise the image input, if any.
    image_info = "未收到图片" if image is None else f"图片尺寸: {image.shape}"

    return audio_info, image_info
|
| 736 |
+
|
| 737 |
+
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    audio,  # multimodal input: audio as (sample_rate, ndarray)
    image  # multimodal input: image as an ndarray
):
    """Stream a chat reply for *message*, optionally transcribing audio/image inputs.

    Yields ``(partial_reply_text, [])`` tuples as tokens arrive from the
    chat-completions endpoint. The multimodal recognition results are
    accumulated in ``multimodal_content``.

    NOTE(review): ``multimodal_content`` is built below but never merged
    into the chat messages in this function — only the original *message*
    is sent to the LLM (see the commented-out merge block further down).
    """
    # Multimodal handling: run speech/image recognition and collect the
    # results as bracketed text fragments.
    multimodal_content = ""

    # Multimodal configuration is fetched inside each branch below.

    if audio is not None:
        try:
            audio_sample_rate, audio_data = audio
            multimodal_content += f"\n[音频信息: 采样率 {audio_sample_rate}Hz, 时长 {len(audio_data)/audio_sample_rate:.2f}秒]"

            # Azure Speech speech-to-text.
            azure_speech_config = get_hf_azure_speech_config()
            azure_speech_key = azure_speech_config.get('key')
            azure_speech_region = azure_speech_config.get('region')

            if azure_speech_key and azure_speech_region:
                import tempfile
                import soundfile as sf
                import os

                # Persist the in-memory audio to a temp WAV file for the API.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
                    sf.write(temp_audio.name, audio_data, audio_sample_rate)
                    temp_audio_path = temp_audio.name

                audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, temp_audio_path)
                if audio_text:
                    multimodal_content += f"\n[语音识别结果: {audio_text}]"
                else:
                    multimodal_content += "\n[语音识别失败]"

                # Clean up the temp file.
                os.unlink(temp_audio_path)
            else:
                multimodal_content += "\n[Azure Speech API配置不完整,无法进行语音识别]"

        except Exception as e:
            multimodal_content += f"\n[音频处理错误: {str(e)}]"

    if image is not None:
        try:
            multimodal_content += f"\n[图片信息: 尺寸 {image.shape}]"

            # Image recognition.
            # NOTE(review): the guard checks Xunfei credentials but the call
            # below targets Azure Cognitive Services — confirm which provider
            # is intended.
            if xunfei_appid and xunfei_apikey and xunfei_apisecret:
                import tempfile
                from PIL import Image
                import os

                # Persist the ndarray to a temp JPEG for the API.
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
                    if len(image.shape) == 3:  # RGB image
                        pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
                    else:  # grayscale image
                        pil_image = Image.fromarray(image.astype('uint8'), 'L')

                    pil_image.save(temp_image.name, 'JPEG')
                    temp_image_path = temp_image.name

                # NOTE(review): hardcoded Azure endpoint and API key checked
                # into source — move to config/env and rotate this credential.
                image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=temp_image_path)
                if image_text:
                    multimodal_content += f"\n[图像识别结果: {image_text}]"
                else:
                    multimodal_content += "\n[图像识别失败]"

                # Clean up the temp file.
                os.unlink(temp_image_path)
            else:
                multimodal_content += "\n[讯飞API配置不完整,无法进行图像识别]"

        except Exception as e:
            multimodal_content += f"\n[图像处理错误: {str(e)}]"

    # Combining multimodal content with the user text was considered but is
    # intentionally disabled: the chat turn uses only the raw message for
    # conversational coherence, while to-do generation may use the combined
    # text elsewhere.
    # combined_message = message
    # if multimodal_content:
    #     combined_message += "\n" + multimodal_content

    # Build the chat-completions message list from system prompt + history.
    chat_messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            chat_messages.append({"role": "user", "content": val[0]})
        if val[1]:
            chat_messages.append({"role": "assistant", "content": val[1]})
    chat_messages.append({"role": "user", "content": message})  # chat uses the raw message

    chat_response_stream = ""
    if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
        # NOTE(review): "LLM333" in this log message looks like a typo.
        logger.error("Filter API 配置不完整,无法调用 LLM333。")
        yield "Filter API 配置不完整,无法提供聊天回复。", []
        return

    headers = {
        "Authorization": f"Bearer {Filter_API_KEY}",
        "Accept": "application/json"
    }
    payload = {
        "model": Filter_MODEL_NAME,
        "messages": chat_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": True  # chat needs streamed output
    }
    api_url = f"{Filter_API_BASE_URL}/chat/completions"

    try:
        # NOTE(review): no timeout= on this request; a hung endpoint blocks
        # the generator indefinitely.
        response = requests.post(api_url, headers=headers, json=payload, stream=True)
        response.raise_for_status()  # raise on HTTP error status

        for chunk in response.iter_content(chunk_size=None):
            if chunk:
                try:
                    # The streamed output is SSE: each event line starts with
                    # "data: " followed by a JSON payload.
                    # NOTE(review): decoding/splitting per chunk assumes SSE
                    # lines never straddle a chunk boundary; iter_lines()
                    # would be more robust.
                    for line in chunk.decode('utf-8').splitlines():
                        if line.startswith('data: '):
                            json_data = line[len('data: '):]
                            if json_data.strip() == '[DONE]':
                                break
                            data = json.loads(json_data)
                            # Guard: choices may be absent or empty.
                            if 'choices' in data and len(data['choices']) > 0:
                                token = data['choices'][0]['delta'].get('content', '')
                                if token:
                                    chat_response_stream += token
                                    yield chat_response_stream, []
                except json.JSONDecodeError:
                    logger.warning(f"无法解析流式响应块: {chunk.decode('utf-8')}")
                except Exception as e:
                    logger.error(f"处理流式响应时发生错误: {e}")
                    yield chat_response_stream + f"\n\n错误: {e}", []

    except requests.exceptions.RequestException as e:
        logger.error(f"调用 NVIDIA API 失败: {e}")
        yield f"调用 NVIDIA API 失败: {e}", []
|
| 885 |
+
|
| 886 |
+
# Global list accumulating every generated to-do item across interactions.
all_todos_global = []
|
| 888 |
+
|
| 889 |
+
# 创建自定义的聊天界面
|
| 890 |
+
with gr.Blocks() as app:
|
| 891 |
+
gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")
|
| 892 |
+
|
| 893 |
+
with gr.Row():
|
| 894 |
+
with gr.Column(scale=2):
|
| 895 |
+
gr.Markdown("## Chat Interface")
|
| 896 |
+
chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages") # 推荐使用 type="messages"
|
| 897 |
+
msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")
|
| 898 |
+
|
| 899 |
+
with gr.Row():
|
| 900 |
+
audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
|
| 901 |
+
image_input = gr.Image(label="上传图片", type="numpy")
|
| 902 |
+
|
| 903 |
+
with gr.Accordion("高级设置", open=False):
|
| 904 |
+
system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
|
| 905 |
+
max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)") # 增加聊天模型参数范围
|
| 906 |
+
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
|
| 907 |
+
top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")
|
| 908 |
+
|
| 909 |
+
with gr.Row():
|
| 910 |
+
submit_btn = gr.Button("发送", variant="primary")
|
| 911 |
+
clear_btn = gr.Button("清除聊天和ToDo")
|
| 912 |
+
|
| 913 |
+
with gr.Column(scale=1):
|
| 914 |
+
gr.Markdown("## Generated ToDo List")
|
| 915 |
+
todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
|
| 916 |
+
datatype=["number", "str", "str"],
|
| 917 |
+
row_count=(0, "dynamic"),
|
| 918 |
+
col_count=(3, "fixed"),
|
| 919 |
+
label="待办事项列表")
|
| 920 |
+
|
| 921 |
+
def user(user_message, chat_history):
    """Append the user's message to the chat log (Gradio type="messages" format).

    Returns an empty string (to clear the textbox) plus the updated history.
    """
    history = chat_history or []
    history.append({"role": "user", "content": user_message})
    return "", history
|
| 926 |
+
|
| 927 |
+
def bot_interaction(chat_history, system_message, max_tokens, temperature, top_p, audio, image):
    """Drive one chat turn: stream the LLM reply, then generate the to-do list.

    Generator: yields ``(chat_history_state, todo_rows)`` pairs — partial
    states while the reply streams, then a final state with the assistant
    message appended and the generated to-dos.
    """
    # The latest user message is the last history entry (appended by user()).
    user_message_for_chat = ""
    if chat_history and chat_history[-1]["role"] == "user":
        user_message_for_chat = chat_history[-1]["content"]

    # Text fed to to-do generation; multimodal enrichment is handled elsewhere.
    text_for_todolist = user_message_for_chat
    # Hooks for extracting text from audio/image and appending it:
    # multimodal_text = process_multimodal_inputs(audio, image)
    # if multimodal_text:
    #     text_for_todolist += "\n" + multimodal_text

    # 1. Generate the streamed chat reply.
    # respond() expects history as list[tuple[user, assistant]] while the
    # type="messages" chatbot stores [{'role': ..., 'content': ...}, ...],
    # so convert, pairing each user turn with the following assistant turn.
    formatted_history_for_respond = []
    temp_user_msg = None
    for item in chat_history[:-1]:  # exclude the latest user message; it is passed as the current message
        if item["role"] == "user":
            temp_user_msg = item["content"]
        elif item["role"] == "assistant" and temp_user_msg is not None:
            formatted_history_for_respond.append((temp_user_msg, item["content"]))
            temp_user_msg = None
        elif item["role"] == "assistant" and temp_user_msg is None:  # bot spoke first
            formatted_history_for_respond.append(("", item["content"]))

    chat_stream_generator = respond(
        user_message_for_chat,
        formatted_history_for_respond,  # converted history
        system_message,
        max_tokens,
        temperature,
        top_p,
        audio,
        image
    )

    full_chat_response = ""
    current_todos = []

    for chat_response_part, _ in chat_stream_generator:
        full_chat_response = chat_response_part
        # Streaming update of the chatbot widget is handled by the Gradio
        # event wiring; nothing to do per-token here.
        if chat_history and chat_history[-1]["role"] == "user":
            pass  # placeholder: incremental assistant message handled at the end
        # NOTE(review): this yields a [None, text] tuple-style entry into a
        # type="messages" history — format mismatch; the final yield below
        # uses the correct dict form.
        yield chat_history + [[None, full_chat_response]], current_todos

    # After the stream ends the assistant message is appended once, below;
    # intermediate bookkeeping is left to Gradio's update mechanism.
    if chat_history and full_chat_response:
        pass  # chatbot update handled by the Gradio event chain

    # 2. Once the reply is complete, generate/update the to-do list.
    if text_for_todolist:
        # Unique per-input id so separate inputs produce distinct to-do keys.
        message_id_for_todo = f"hf_app_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
        new_todo_items = generate_todolist_from_text(text_for_todolist, message_id_for_todo)
        current_todos = new_todo_items

    # Final state: history copy with the assistant reply appended.
    final_chat_history = list(chat_history)  # shallow copy
    if full_chat_response:
        final_chat_history.append({"role": "assistant", "content": full_chat_response})

    yield final_chat_history, current_todos
|
| 1004 |
+
|
| 1005 |
+
# 连接事件 (适配 type="messages")
|
| 1006 |
+
# Gradio 的流式更新通常是:
|
| 1007 |
+
# 1. user 函数准备输入,返回 (空输入框, 更新后的聊天记录)
|
| 1008 |
+
# 2. bot_interaction 函数是一个生成器,yield (部分聊天记录, 部分ToDo)
|
| 1009 |
+
# msg.submit 和 submit_btn.click 的 outputs 需要对应 bot_interaction 的 yield
|
| 1010 |
+
|
| 1011 |
+
# 简化版,非流式更新 chatbot,流式更新由 respond 内部的 yield 控制
|
| 1012 |
+
# 但 respond 的 yield 格式 (str, list) 与 bot_interaction (list, list) 不同
|
| 1013 |
+
# 需要调整 respond 的 yield 或 bot_interaction 的处理
|
| 1014 |
+
|
| 1015 |
+
# 调整后的事件处理,以更好地支持流式聊天和ToDo更新
|
| 1016 |
+
def process_filtered_result_for_todo(filtered_result, content, source_type):
    """处理过滤结果并生成todolist的辅助函数"""
    result_is_dict = isinstance(filtered_result, dict)

    # Filter module reported an error: surface it as an error row.
    if result_is_dict and "error" in filtered_result:
        logger.error(f"{source_type} Filter 模块处理失败: {filtered_result['error']}")
        return [["Error", f"{source_type}: {filtered_result['error']}", "Filter Failed"]]

    # Dict result classified as "其他": no to-do needed.
    if result_is_dict and filtered_result.get("分类") == "其他":
        logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
        return [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]

    if isinstance(filtered_result, list):
        # Take the category from the first dict entry that carries one.
        category = next(
            (entry["分类"] for entry in filtered_result
             if isinstance(entry, dict) and "分类" in entry),
            None,
        )
        if category == "其他":
            logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
            return [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
        logger.info(f"{source_type}消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。")
    else:
        # Dict without "其他" classification, or any other type.
        logger.info(f"{source_type}消息被 Filter 模块归类为 '{filtered_result.get('分类') if isinstance(filtered_result, dict) else '未知'}',继续生成 ToDo List。")

    # Shared tail: generate the to-dos and tag each row with its source.
    todos = []
    if content:
        msg_id_todo = f"hf_app_todo_{source_type}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
        todos = generate_todolist_from_text(content, msg_id_todo)
        for row in todos:
            if len(row) > 1:
                row[1] = f"[{source_type}] {row[1]}"

    return todos
|
| 1059 |
+
|
| 1060 |
+
def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f, request: gr.Request):
    """Main submit handler (generator): streams chat updates and the ToDo table.

    Pipeline:
      1. Transcribe audio (Azure Speech) / describe image (Azure Vision) and
         merge the results with the typed text into the effective user message.
      2. Append the user turn, then stream the assistant reply from respond().
      3. Run the Filter LLM per input modality, generate ToDo items, and
         append real items to the global ToDo list.

    Yields (chat_history, todo_rows) tuples for Gradio to render.
    """
    global all_todos_global

    # Identify the client; the IP is used to name per-client temp media files.
    client_ip = get_client_ip(request, True)
    print(f"Processing request from IP: {client_ip}")

    # --- 1. Multimodal pre-processing -------------------------------------
    multimodal_text_content = ""
    logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")

    # Azure Speech credentials (audio is skipped when the config is incomplete).
    azure_speech_config = get_hf_azure_speech_config()
    azure_speech_key = azure_speech_config.get('key')
    azure_speech_region = azure_speech_config.get('region')
    logger.info(f"Azure Speech配置状态 - key: {bool(azure_speech_key)}, region: {bool(azure_speech_region)}")

    # Audio -> text via Azure Speech.
    if audio_f is not None and azure_speech_key and azure_speech_region:
        logger.info("开始处理音频输入...")
        try:
            audio_sample_rate, audio_data = audio_f
            logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")
            audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
            save_audio(audio_f, audio_filename)
            logger.info(f"音频已保存: {audio_filename}")
            audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename)
            logger.info(f"音频识别结果: {audio_text}")
            if audio_text:
                multimodal_text_content += f"音频内容: {audio_text}"
                logger.info("音频处理完成")
            else:
                logger.warning("音频处理失败")
        except Exception as e:
            logger.error(f"音频处理错误: {str(e)}")
    elif audio_f is not None:
        logger.warning("音频文件存在但Azure Speech配置不完整,跳过音频处理")

    # Image -> text via Azure Computer Vision.
    if image_f is not None:
        logger.info("开始处理图像输入...")
        try:
            logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")
            image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
            save_image(image_f, image_filename)
            logger.info(f"图像已保存: {image_filename}")
            # FIXME(security): Azure endpoint and API key are hard-coded and
            # committed to the repo; move them into todogen_LLM_config.yaml or
            # environment variables and rotate the exposed key.
            image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
            logger.info(f"图像识别结果: {image_text}")
            if image_text:
                if multimodal_text_content:  # separate from any audio text
                    multimodal_text_content += "\n"
                multimodal_text_content += f"图像内容: {image_text}"
                logger.info("图像处理完成")
            else:
                logger.warning("图像处理失败")
        except Exception as e:
            logger.error(f"图像处理错误: {str(e)}")
    # NOTE(fix): the original `elif image_f is not None:` fallback here was
    # unreachable (identical condition to the `if` above) and was removed.

    # --- 2. Decide the effective user message -----------------------------
    # Typed text wins; multimodal text fills in or supplements it.
    final_user_content = user_msg_content.strip() if user_msg_content else ""
    if not final_user_content and multimodal_text_content:
        final_user_content = multimodal_text_content
        logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
    elif final_user_content and multimodal_text_content:
        final_user_content = f"{final_user_content}\n{multimodal_text_content}"
        logger.info(f"用户有文本输入,多模态内容作为补充")

    if not final_user_content:
        final_user_content = "[无输入内容]"
        logger.warning("用户没有提供任何输入内容(文本、音频或图像)")

    logger.info(f"最终用户输入内容: {final_user_content}")

    # Append the user turn and push an immediate UI update.
    if not ch_history: ch_history = []
    ch_history.append({"role": "user", "content": final_user_content})
    yield ch_history, []

    # Convert role-based history into the (user, assistant) pairs respond() expects.
    formatted_hist_for_respond = []
    temp_user_msg_for_hist = None
    for item_hist in ch_history[:-1]:  # exclude the turn just appended
        if item_hist["role"] == "user":
            temp_user_msg_for_hist = item_hist["content"]
        elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
            formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
            temp_user_msg_for_hist = None
        elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
            formatted_hist_for_respond.append(("", item_hist["content"]))

    # Reserve an assistant slot, then stream tokens into it.
    ch_history.append({"role": "assistant", "content": ""})
    full_bot_response = ""
    for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
        full_bot_response = bot_response_token
        ch_history[-1]["content"] = full_bot_response  # update the last assistant message
        yield ch_history, []

    # --- 3. ToDo generation, one pass per modality ------------------------
    new_todos_list = []

    # Text input.  FIX: guard against None — Gradio can deliver None for an
    # empty textbox and the original `user_msg_content.strip()` raised
    # AttributeError in that case.
    if user_msg_content and user_msg_content.strip():
        logger.info(f"处理文字输入生成ToDo: {user_msg_content.strip()}")
        text_filtered_result = filter_message_with_llm(user_msg_content.strip())
        text_todos = process_filtered_result_for_todo(text_filtered_result, user_msg_content.strip(), "文字")
        new_todos_list.extend(text_todos)

    # Audio input (re-transcribed; the step-1 transcript is not reused here).
    if audio_f is not None and azure_speech_key and azure_speech_region:
        try:
            audio_sample_rate, audio_data = audio_f
            audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
            save_audio(audio_f, audio_filename)
            audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename)
            if audio_text:
                logger.info(f"处理音频输入生成ToDo: {audio_text}")
                audio_filtered_result = filter_message_with_llm(audio_text)
                audio_todos = process_filtered_result_for_todo(audio_filtered_result, audio_text, "音频")
                new_todos_list.extend(audio_todos)
        except Exception as e:
            logger.error(f"音频处理错误: {str(e)}")

    # Image input (re-analyzed; same hard-coded credentials as above — see FIXME).
    if image_f is not None:
        try:
            image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
            save_image(image_f, image_filename)
            image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
            if image_text:
                logger.info(f"处理图片输入生成ToDo: {image_text}")
                image_filtered_result = filter_message_with_llm(image_text)
                image_todos = process_filtered_result_for_todo(image_filtered_result, image_text, "图片")
                new_todos_list.extend(image_todos)
        except Exception as e:
            logger.error(f"图片处理错误: {str(e)}")

    # Fallback: nothing produced per-modality — classify the merged content.
    if not new_todos_list and final_user_content:
        logger.info(f"使用整合内容生成ToDo: {final_user_content}")
        filtered_result = filter_message_with_llm(final_user_content)

        if isinstance(filtered_result, dict) and "error" in filtered_result:
            logger.error(f"Filter 模块处理失败: {filtered_result['error']}")
            new_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
        elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
            logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
            new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
        elif isinstance(filtered_result, list):
            category = None

            # Empty list: classification silently failed — still try to generate.
            if not filtered_result:
                logger.warning("Filter 模块返回了空列表,将继续生成 ToDo List。")
                if final_user_content:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
                if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])):
                    # Renumber so IDs continue the global sequence.
                    for i, todo in enumerate(new_todos_list):
                        todo[0] = len(all_todos_global) + i + 1
                    all_todos_global.extend(new_todos_list)
                yield ch_history, all_todos_global
                return

            # Use the first dict element's classification, if any.
            valid_item = None
            for item in filtered_result:
                if isinstance(item, dict):
                    valid_item = item
                    if "分类" in item:
                        category = item["分类"]
                    break

            if valid_item is None:
                logger.warning(f"Filter 模块返回的列表中没有有效的字典元素: {filtered_result}")
                if final_user_content:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
                if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])):
                    for i, todo in enumerate(new_todos_list):
                        todo[0] = len(all_todos_global) + i + 1
                    all_todos_global.extend(new_todos_list)
                yield ch_history, all_todos_global
                return

            if category == "其他":
                logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
                new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
            else:
                logger.info(f"消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。")
                if final_user_content:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
        else:
            # Dict with a classification other than "其他".
            logger.info(f"消息被 Filter 模块归类为 '{filtered_result.get('分类')}',继续生成 ToDo List。")
            if final_user_content:
                msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)

    # Commit real items (not single Info/Error placeholders) to the global list.
    if new_todos_list and not (len(new_todos_list) == 1 and ("Info" in str(new_todos_list[0]) or "Error" in str(new_todos_list[0]))):
        for i, todo in enumerate(new_todos_list):
            todo[0] = len(all_todos_global) + i + 1
        all_todos_global.extend(new_todos_list)

    yield ch_history, all_todos_global  # final update: chat + full ToDo list
|
| 1296 |
+
|
| 1297 |
+
# Wire both the Submit button and pressing Enter in the message textbox to the
# same streaming handler; each call updates the chatbot and the ToDo table.
submit_btn.click(
    handle_submit,
    [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
    [chatbot, todolist_df]
)
msg.submit(
    handle_submit,
    [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
    [chatbot, todolist_df]
)
|
| 1307 |
+
|
| 1308 |
+
def clear_all():
    """Reset the session: empty the global ToDo store and blank the UI.

    Returns a 3-tuple mapped to (chatbot, todolist_df, msg): None clears the
    two components, "" empties the message textbox.
    """
    global all_todos_global
    all_todos_global = []
    cleared_chatbot, cleared_todos, cleared_msg = None, None, ""
    return cleared_chatbot, cleared_todos, cleared_msg
|
| 1312 |
+
# The Clear button resets chat, ToDo table and input box immediately (unqueued).
clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)
|
| 1313 |
+
|
| 1314 |
+
# Legacy Audio/Image Processing tab (kept as-is; modify as needed).
with gr.Tab("Audio/Image Processing (Original)"):
    gr.Markdown("## 处理音频和图片")
    # Inputs arrive as numpy arrays: audio as (sample_rate, data), image as HxWxC.
    audio_processor = gr.Audio(label="上传音频", type="numpy")
    image_processor = gr.Image(label="上传图片", type="numpy")
    process_btn = gr.Button("处理", variant="primary")
    audio_output = gr.Textbox(label="音频信息")
    image_output = gr.Textbox(label="图片信息")

    # `process` inspects both inputs and returns one summary string per modality.
    process_btn.click(
        process,
        inputs=[audio_processor, image_processor],
        outputs=[audio_output, image_output]
    )
|
| 1328 |
+
|
| 1329 |
+
# Run the Gradio app directly; debug=True enables verbose errors and reload.
if __name__ == "__main__":
    app.launch(debug=True)
|
app_pro.py
ADDED
|
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import yaml
|
| 5 |
+
import re
|
| 6 |
+
import logging
|
| 7 |
+
import io
|
| 8 |
+
import sys
|
| 9 |
+
import re
|
| 10 |
+
from datetime import datetime, timezone, timedelta
|
| 11 |
+
import requests
|
| 12 |
+
from tools import * #gege的多模态
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Parsed configuration cache; populated lazily on first access.
CONFIG = None
# The YAML config file is expected to sit next to this module.
HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"

def load_hf_config():
    """Load and cache the YAML configuration file.

    The file is read at most once; on any load failure the cache becomes an
    empty dict so callers always receive a mapping.
    """
    global CONFIG
    if CONFIG is not None:
        return CONFIG
    try:
        with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
            CONFIG = yaml.safe_load(f)
        print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
    except FileNotFoundError:
        print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
        CONFIG = {}
    except Exception as e:
        print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
        CONFIG = {}
    return CONFIG
|
| 33 |
+
|
| 34 |
+
def get_hf_openai_config():
    """Return the `openai` section of the config (empty dict if absent)."""
    return load_hf_config().get('openai', {})
|
| 38 |
+
|
| 39 |
+
def get_hf_openai_filter_config():
    """Return the `openai_filter` section of the config (empty dict if absent)."""
    return load_hf_config().get('openai_filter', {})
|
| 43 |
+
|
| 44 |
+
def get_hf_xunfei_config():
    """Return the `xunfei` (iFLYTEK) section of the config (empty dict if absent)."""
    return load_hf_config().get('xunfei', {})
|
| 48 |
+
|
| 49 |
+
def get_hf_paths_config():
    """Resolve prompt/example file paths relative to this module's directory.

    Entries missing from the `paths` config section fall back to the
    conventional default filenames.
    """
    paths_cfg = load_hf_config().get('paths', {})
    base = Path(__file__).resolve().parent
    defaults = {
        'prompt_template': 'prompt_template.txt',
        'true_positive_examples': 'TruePositive_few_shot.txt',
        'false_positive_examples': 'FalsePositive_few_shot.txt',
    }
    resolved = {'base_dir': base}
    for key, fallback in defaults.items():
        resolved[key] = base / paths_cfg.get(key, fallback)
    return resolved
|
| 60 |
+
|
| 61 |
+
# --- Module-level API configuration, resolved once at import time ---------

# Main generation LLM (OpenAI-compatible NVIDIA endpoint).
llm_config = get_hf_openai_config()
NVIDIA_API_BASE_URL = llm_config.get('base_url')
NVIDIA_API_KEY = llm_config.get('api_key')
NVIDIA_MODEL_NAME = llm_config.get('model')

# Separate LLM used only for message classification ("filter").
filter_config = get_hf_openai_filter_config()
Filter_API_BASE_URL = filter_config.get('base_url_filter')
Filter_API_KEY = filter_config.get('api_key_filter')
Filter_MODEL_NAME = filter_config.get('model_filter')

# Degrade incomplete configuration to empty strings (instead of None) so
# downstream truthiness checks and f-strings do not blow up.
if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
    print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
    NVIDIA_API_BASE_URL = ""
    NVIDIA_API_KEY = ""
    NVIDIA_MODEL_NAME = ""

if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
    print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
    Filter_API_BASE_URL = ""
    Filter_API_KEY = ""
    Filter_MODEL_NAME = ""

# Module-wide logger; basicConfig here affects the root logger for the app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
| 85 |
+
|
| 86 |
+
def load_single_few_shot_file_hf(file_path: Path) -> str:
    """Read one few-shot example file and escape braces for str.format.

    Doubling `{` / `}` lets the file content be embedded verbatim in the
    prompt template without being treated as format placeholders.  Any read
    failure (missing file included) yields an empty string — the examples
    are optional.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as fh:
            raw = fh.read()
    except Exception:
        return ""
    return raw.replace('{', '{{').replace('}', '}}')
|
| 97 |
+
|
| 98 |
+
# Prompt template and few-shot example texts, loaded once at import time.
PROMPT_TEMPLATE_CONTENT = ""
TRUE_POSITIVE_EXAMPLES_CONTENT = ""
FALSE_POSITIVE_EXAMPLES_CONTENT = ""

def load_prompt_data_hf():
    """Populate the module-level prompt template and few-shot example strings.

    When the template file is missing, the content is set to an error marker
    that generate_todolist_from_text() checks before formatting.
    """
    global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
    paths = get_hf_paths_config()
    try:
        with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
            PROMPT_TEMPLATE_CONTENT = f.read()
    except FileNotFoundError:
        PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."

    # Few-shot examples are optional; the loader returns "" for missing files.
    TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
    FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])

# Eager load so the rest of the module can assume the globals are populated.
load_prompt_data_hf()
|
| 116 |
+
|
| 117 |
+
def _process_parsed_json(parsed_data):
|
| 118 |
+
"""处理解析后的JSON数据,确保格式正确"""
|
| 119 |
+
try:
|
| 120 |
+
if isinstance(parsed_data, list):
|
| 121 |
+
if not parsed_data:
|
| 122 |
+
return [{}]
|
| 123 |
+
|
| 124 |
+
processed_list = []
|
| 125 |
+
for item in parsed_data:
|
| 126 |
+
if isinstance(item, dict):
|
| 127 |
+
processed_list.append(item)
|
| 128 |
+
else:
|
| 129 |
+
try:
|
| 130 |
+
processed_list.append({"content": str(item)})
|
| 131 |
+
except:
|
| 132 |
+
processed_list.append({"content": "无法转换的项目"})
|
| 133 |
+
|
| 134 |
+
if not processed_list:
|
| 135 |
+
return [{}]
|
| 136 |
+
|
| 137 |
+
return processed_list
|
| 138 |
+
|
| 139 |
+
elif isinstance(parsed_data, dict):
|
| 140 |
+
return parsed_data
|
| 141 |
+
|
| 142 |
+
else:
|
| 143 |
+
return {"content": str(parsed_data)}
|
| 144 |
+
|
| 145 |
+
except Exception as e:
|
| 146 |
+
return {"error": f"Error processing parsed JSON: {e}"}
|
| 147 |
+
|
| 148 |
+
def json_parser(text: str) -> dict:
|
| 149 |
+
"""从文本中解析JSON数据,支持多种格式"""
|
| 150 |
+
try:
|
| 151 |
+
try:
|
| 152 |
+
parsed_data = json.loads(text)
|
| 153 |
+
return _process_parsed_json(parsed_data)
|
| 154 |
+
except json.JSONDecodeError:
|
| 155 |
+
pass
|
| 156 |
+
|
| 157 |
+
match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
|
| 158 |
+
if match:
|
| 159 |
+
json_str = match.group(1).strip()
|
| 160 |
+
json_str = re.sub(r',\s*]', ']', json_str)
|
| 161 |
+
json_str = re.sub(r',\s*}', '}', json_str)
|
| 162 |
+
try:
|
| 163 |
+
parsed_data = json.loads(json_str)
|
| 164 |
+
return _process_parsed_json(parsed_data)
|
| 165 |
+
except json.JSONDecodeError:
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
|
| 169 |
+
if array_match:
|
| 170 |
+
potential_json = array_match.group(0).strip()
|
| 171 |
+
try:
|
| 172 |
+
parsed_data = json.loads(potential_json)
|
| 173 |
+
return _process_parsed_json(parsed_data)
|
| 174 |
+
except json.JSONDecodeError:
|
| 175 |
+
pass
|
| 176 |
+
|
| 177 |
+
object_match = re.search(r'\{.*?\}', text, re.DOTALL)
|
| 178 |
+
if object_match:
|
| 179 |
+
potential_json = object_match.group(0).strip()
|
| 180 |
+
try:
|
| 181 |
+
parsed_data = json.loads(potential_json)
|
| 182 |
+
return _process_parsed_json(parsed_data)
|
| 183 |
+
except json.JSONDecodeError:
|
| 184 |
+
pass
|
| 185 |
+
|
| 186 |
+
return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
|
| 190 |
+
|
| 191 |
+
def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
    """Classify one message with the Filter LLM.

    Sends `text_input` (wrapped as a one-element batch keyed by `message_id`)
    to the filter endpoint and parses the JSON classification it returns.

    Returns a list of classification dicts.  On any failure (missing config,
    HTTP error, unparsable response) returns a one-element list whose dict
    carries an "error" key and/or defaults to the "其他" (other) category so
    callers can always proceed.
    """
    mock_data = [(text_input, message_id)]

    system_prompt = """
# 角色
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。

# 任务
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略

# 要求
1. 以json格式输出
2. content简洁提炼关键词,字符数<20以内
3. 输入条数和输出条数完全一样

# 输出示例
```
[
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
]
```
"""

    llm_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": str(mock_data)}
    ]

    try:
        if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
            return [{"error": "Filter API configuration incomplete", "-": "-"}]

        headers = {
            "Authorization": f"Bearer {Filter_API_KEY}",
            "Accept": "application/json"
        }
        # Deterministic settings (temperature 0) — classification should be stable.
        payload = {
            "model": Filter_MODEL_NAME,
            "messages": llm_messages,
            "temperature": 0.0,
            "top_p": 0.95,
            "max_tokens": 1024,
            "stream": False
        }

        api_url = f"{Filter_API_BASE_URL}/chat/completions"

        try:
            # FIX: added timeout — without it a hung endpoint blocks the
            # request (and the Gradio worker) indefinitely.
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            raw_llm_response = response.json()["choices"][0]["message"]["content"]
        except requests.exceptions.RequestException as e:
            return [{"error": f"Filter API call failed: {e}", "-": "-"}]

        # Strip markdown fencing before JSON extraction.
        raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
        parsed_filter_data = json_parser(raw_llm_response)

        if "error" in parsed_filter_data:
            return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]

        if isinstance(parsed_filter_data, list) and parsed_filter_data:
            # Business rule: "缴费支出" receipts are spending notices, not
            # bills awaiting payment — reclassify them as "其他".
            for item in parsed_filter_data:
                if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
                    item["分类"] = "其他"

            # Backfill any message_id the LLM dropped with a default "其他" row
            # so the output count always matches the input count.
            request_id_list = {message_id}
            response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
            diff = request_id_list - response_id_list

            if diff:
                for missed_id in diff:
                    parsed_filter_data.append({
                        "message_id": missed_id,
                        "content": text_input[:20],
                        "物流取件": 0,
                        "欠费缴纳": 0,
                        "待付(还)款": 0,
                        "会议邀约": 0,
                        "其他": 100,
                        "分类": "其他"
                    })

            return parsed_filter_data
        else:
            return [{
                "message_id": message_id,
                "content": text_input[:20],
                "物流取件": 0,
                "欠费缴纳": 0,
                "待付(还)款": 0,
                "会议邀约": 0,
                "其他": 100,
                "分类": "其他",
                "error": "Filter LLM returned empty or unexpected format"
            }]

    except Exception as e:
        return [{
            "message_id": message_id,
            "content": text_input[:20],
            "物流取件": 0,
            "欠费缴纳": 0,
            "待付(还)款": 0,
            "会议邀约": 0,
            "其他": 100,
            "分类": "其他",
            "error": f"Filter LLM call/parse error: {str(e)}"
        }]
|
| 302 |
+
|
| 303 |
+
def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
|
| 304 |
+
"""从文本生成待办事项列表"""
|
| 305 |
+
if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
|
| 306 |
+
return [["error", "Prompt template not loaded", "-"]]
|
| 307 |
+
|
| 308 |
+
current_time_iso = datetime.now(timezone.utc).isoformat()
|
| 309 |
+
content_escaped = text_input.replace('{', '{{').replace('}', '}}')
|
| 310 |
+
|
| 311 |
+
formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
|
| 312 |
+
true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
|
| 313 |
+
false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
|
| 314 |
+
current_time=current_time_iso,
|
| 315 |
+
message_id=message_id,
|
| 316 |
+
content_escaped=content_escaped
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
enhanced_prompt = formatted_prompt + """
|
| 320 |
+
|
| 321 |
+
# 重要提示
|
| 322 |
+
请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
|
| 323 |
+
你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
llm_messages = [
|
| 327 |
+
{"role": "user", "content": enhanced_prompt}
|
| 328 |
+
]
|
| 329 |
+
|
| 330 |
+
try:
|
| 331 |
+
if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
|
| 332 |
+
todo_item = {
|
| 333 |
+
message_id: {
|
| 334 |
+
"is_todo": True,
|
| 335 |
+
"end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
|
| 336 |
+
"location": "线上:中国移动APP",
|
| 337 |
+
"todo_content": "缴纳话费",
|
| 338 |
+
"urgency": "important"
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
todo_content = "缴纳话费"
|
| 343 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
| 344 |
+
location = todo_item[message_id]["location"]
|
| 345 |
+
|
| 346 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
| 347 |
+
|
| 348 |
+
output_for_df = []
|
| 349 |
+
output_for_df.append([1, combined_content, "重要"])
|
| 350 |
+
|
| 351 |
+
return output_for_df
|
| 352 |
+
|
| 353 |
+
elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
|
| 354 |
+
meeting_time = None
|
| 355 |
+
meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
|
| 356 |
+
meeting_match = re.search(meeting_pattern, text_input)
|
| 357 |
+
|
| 358 |
+
if meeting_match:
|
| 359 |
+
meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
|
| 360 |
+
else:
|
| 361 |
+
meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
|
| 362 |
+
|
| 363 |
+
todo_item = {
|
| 364 |
+
message_id: {
|
| 365 |
+
"is_todo": True,
|
| 366 |
+
"end_time": meeting_time,
|
| 367 |
+
"location": "线上:会议软件",
|
| 368 |
+
"todo_content": "参加会议",
|
| 369 |
+
"urgency": "important"
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
todo_content = "参加会议"
|
| 374 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
| 375 |
+
location = todo_item[message_id]["location"]
|
| 376 |
+
|
| 377 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
| 378 |
+
|
| 379 |
+
output_for_df = []
|
| 380 |
+
output_for_df.append([1, combined_content, "重要"])
|
| 381 |
+
|
| 382 |
+
return output_for_df
|
| 383 |
+
|
| 384 |
+
elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
|
| 385 |
+
pickup_code = None
|
| 386 |
+
code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
|
| 387 |
+
code_match = re.search(code_pattern, text_input)
|
| 388 |
+
|
| 389 |
+
todo_content = "取快递"
|
| 390 |
+
if code_match:
|
| 391 |
+
pickup_code = code_match.group(1)
|
| 392 |
+
todo_content = f"取快递(取件码:{pickup_code})"
|
| 393 |
+
|
| 394 |
+
todo_item = {
|
| 395 |
+
message_id: {
|
| 396 |
+
"is_todo": True,
|
| 397 |
+
"end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
|
| 398 |
+
"location": "线下:快递柜",
|
| 399 |
+
"todo_content": todo_content,
|
| 400 |
+
"urgency": "important"
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
| 405 |
+
location = todo_item[message_id]["location"]
|
| 406 |
+
|
| 407 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
| 408 |
+
|
| 409 |
+
output_for_df = []
|
| 410 |
+
output_for_df.append([1, combined_content, "重要"])
|
| 411 |
+
|
| 412 |
+
return output_for_df
|
| 413 |
+
|
| 414 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
| 415 |
+
return [["error", "Filter API configuration incomplete", "-"]]
|
| 416 |
+
|
| 417 |
+
headers = {
|
| 418 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
| 419 |
+
"Accept": "application/json"
|
| 420 |
+
}
|
| 421 |
+
payload = {
|
| 422 |
+
"model": Filter_MODEL_NAME,
|
| 423 |
+
"messages": llm_messages,
|
| 424 |
+
"temperature": 0.2,
|
| 425 |
+
"top_p": 0.95,
|
| 426 |
+
"max_tokens": 1024,
|
| 427 |
+
"stream": False
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
| 431 |
+
|
| 432 |
+
try:
|
| 433 |
+
response = requests.post(api_url, headers=headers, json=payload)
|
| 434 |
+
response.raise_for_status()
|
| 435 |
+
raw_llm_response = response.json()['choices'][0]['message']['content']
|
| 436 |
+
except requests.exceptions.RequestException as e:
|
| 437 |
+
return [["error", f"Filter API call failed: {e}", "-"]]
|
| 438 |
+
|
| 439 |
+
parsed_todos_data = json_parser(raw_llm_response)
|
| 440 |
+
|
| 441 |
+
if "error" in parsed_todos_data:
|
| 442 |
+
return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]
|
| 443 |
+
|
| 444 |
+
output_for_df = []
|
| 445 |
+
|
| 446 |
+
if isinstance(parsed_todos_data, dict):
|
| 447 |
+
todo_info = None
|
| 448 |
+
for key, value in parsed_todos_data.items():
|
| 449 |
+
if key == message_id or key == str(message_id):
|
| 450 |
+
todo_info = value
|
| 451 |
+
break
|
| 452 |
+
|
| 453 |
+
if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
|
| 454 |
+
todo_content = todo_info.get("todo_content", "未指定待办内容")
|
| 455 |
+
end_time = todo_info.get("end_time")
|
| 456 |
+
location = todo_info.get("location")
|
| 457 |
+
urgency = todo_info.get("urgency", "unimportant")
|
| 458 |
+
|
| 459 |
+
combined_content = todo_content
|
| 460 |
+
|
| 461 |
+
if end_time and end_time != "null":
|
| 462 |
+
try:
|
| 463 |
+
date_part = end_time.split("T")[0] if "T" in end_time else end_time
|
| 464 |
+
combined_content += f" (截止时间: {date_part}"
|
| 465 |
+
except:
|
| 466 |
+
combined_content += f" (截止时间: {end_time}"
|
| 467 |
+
else:
|
| 468 |
+
combined_content += " ("
|
| 469 |
+
|
| 470 |
+
if location and location != "null":
|
| 471 |
+
combined_content += f", 地点: {location})"
|
| 472 |
+
else:
|
| 473 |
+
combined_content += ")"
|
| 474 |
+
|
| 475 |
+
urgency_display = "一般"
|
| 476 |
+
if urgency == "urgent":
|
| 477 |
+
urgency_display = "紧急"
|
| 478 |
+
elif urgency == "important":
|
| 479 |
+
urgency_display = "重要"
|
| 480 |
+
|
| 481 |
+
output_for_df = []
|
| 482 |
+
output_for_df.append([1, combined_content, urgency_display])
|
| 483 |
+
else:
|
| 484 |
+
output_for_df = []
|
| 485 |
+
output_for_df.append([1, "此消息不包含待办事项", "-"])
|
| 486 |
+
|
| 487 |
+
elif isinstance(parsed_todos_data, list):
|
| 488 |
+
output_for_df = []
|
| 489 |
+
|
| 490 |
+
if not parsed_todos_data:
|
| 491 |
+
return [[1, "未能生成待办事项", "-"]]
|
| 492 |
+
|
| 493 |
+
for i, item in enumerate(parsed_todos_data):
|
| 494 |
+
if isinstance(item, dict):
|
| 495 |
+
todo_content = item.get('todo_content', item.get('content', 'N/A'))
|
| 496 |
+
status = item.get('status', '未完成')
|
| 497 |
+
urgency = item.get('urgency', 'normal')
|
| 498 |
+
|
| 499 |
+
combined_content = todo_content
|
| 500 |
+
|
| 501 |
+
if 'end_time' in item and item['end_time']:
|
| 502 |
+
try:
|
| 503 |
+
if isinstance(item['end_time'], str):
|
| 504 |
+
date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
|
| 505 |
+
combined_content += f" (截止时间: {date_part}"
|
| 506 |
+
else:
|
| 507 |
+
combined_content += f" (截止时间: {str(item['end_time'])}"
|
| 508 |
+
except Exception:
|
| 509 |
+
combined_content += " ("
|
| 510 |
+
else:
|
| 511 |
+
combined_content += " ("
|
| 512 |
+
|
| 513 |
+
if 'location' in item and item['location']:
|
| 514 |
+
combined_content += f", 地点: {item['location']})"
|
| 515 |
+
else:
|
| 516 |
+
combined_content += ")"
|
| 517 |
+
|
| 518 |
+
importance = "一般"
|
| 519 |
+
if urgency == "urgent":
|
| 520 |
+
importance = "紧急"
|
| 521 |
+
elif urgency == "important":
|
| 522 |
+
importance = "重要"
|
| 523 |
+
|
| 524 |
+
output_for_df.append([i + 1, combined_content, importance])
|
| 525 |
+
else:
|
| 526 |
+
try:
|
| 527 |
+
item_str = str(item) if item is not None else "未知项目"
|
| 528 |
+
output_for_df.append([i + 1, item_str, "一般"])
|
| 529 |
+
except Exception:
|
| 530 |
+
output_for_df.append([i + 1, "处理错误的项目", "一般"])
|
| 531 |
+
|
| 532 |
+
if not output_for_df:
|
| 533 |
+
return [["info", "未发现待办事项", "-"]]
|
| 534 |
+
|
| 535 |
+
return output_for_df
|
| 536 |
+
|
| 537 |
+
except Exception as e:
|
| 538 |
+
return [["error", f"LLM call/parse error: {str(e)}", "-"]]
|
| 539 |
+
# 这里------多模态数据从这里调用
|
| 540 |
+
def process(audio, image):
    """Summarize the incoming audio and image payloads as display strings.

    Args:
        audio: Optional (sample_rate, samples) tuple from gr.Audio.
        image: Optional numpy array from gr.Image.

    Returns:
        Tuple of (audio description, image description).
    """
    audio_info = "未收到音频"
    if audio is not None:
        rate, samples = audio
        audio_info = f"音频采样率: {rate}Hz, 数据长度: {len(samples)}"

    image_info = "未收到图片" if image is None else f"图片尺寸: {image.shape}"

    return audio_info, image_info
| 555 |
+
def respond(message, history, system_message, max_tokens, temperature, top_p, audio, image):
    """Stream a chat completion reply as a generator.

    Yields ``(accumulated_text, [])`` tuples so the caller can live-update the
    chat pane; the second element is an (always empty) todo-list placeholder.

    Args:
        message: Current user utterance.
        history: List of (user, assistant) text pairs.
        system_message: System prompt text.
        max_tokens / temperature / top_p: Sampling parameters forwarded verbatim.
        audio, image: Accepted for interface parity with the UI handler; not
            used inside this function.
    """
    # Rebuild the OpenAI-style message list from the pairwise history.
    chat_messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            chat_messages.append({"role": "user", "content": val[0]})
        if val[1]:
            chat_messages.append({"role": "assistant", "content": val[1]})
    chat_messages.append({"role": "user", "content": message})

    chat_response_stream = ""
    if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
        yield "Filter API 配置不完整,无法提供聊天回复。", []
        return

    headers = {
        "Authorization": f"Bearer {Filter_API_KEY}",
        "Accept": "application/json"
    }
    payload = {
        "model": Filter_MODEL_NAME,
        "messages": chat_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": True
    }
    api_url = f"{Filter_API_BASE_URL}/chat/completions"

    try:
        response = requests.post(api_url, headers=headers, json=payload, stream=True)
        response.raise_for_status()

        # Parse the server-sent-event stream manually: each chunk may hold
        # several "data: {...}" lines; "[DONE]" terminates the stream.
        for chunk in response.iter_content(chunk_size=None):
            if chunk:
                try:
                    for line in chunk.decode('utf-8').splitlines():
                        if line.startswith('data: '):
                            json_data = line[len('data: '):]
                            if json_data.strip() == '[DONE]':
                                break
                            data = json.loads(json_data)
                            token = data['choices'][0]['delta'].get('content', '')
                            if token:
                                chat_response_stream += token
                                # Yield the text accumulated so far.
                                yield chat_response_stream, []
                except json.JSONDecodeError:
                    # NOTE(review): a partial SSE line split across chunks is
                    # silently dropped here — tokens can be lost. Confirm
                    # whether buffering across chunks is needed.
                    pass
                except Exception as e:
                    yield chat_response_stream + f"\n\n错误: {e}", []

    except requests.exceptions.RequestException as e:
        yield f"调用 NVIDIA API 失败: {e}", []
+
# Image / multimodal upload entry point: main Gradio UI (chat + todo list).
with gr.Blocks() as app:
    gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## Chat Interface")
            chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages")
            msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")

            with gr.Row():
                audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
                image_input = gr.Image(label="上传图片", type="numpy")

            with gr.Accordion("高级设置", open=False):
                system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)")
                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")

            with gr.Row():
                submit_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.Button("清除聊天和ToDo")

        with gr.Column(scale=1):
            gr.Markdown("## Generated ToDo List")
            todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
                                       datatype=["number", "str", "str"],
                                       row_count=(0, "dynamic"),
                                       col_count=(3, "fixed"),
                                       label="待办事项列表")

    def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f):
        """Handle one submission: transcribe media, stream the chat reply, then build the todo list.

        Generator yielding ``(chat_history, todo_rows)`` pairs for Gradio.
        """
        # First run multimodal processing to obtain transcribed content.
        multimodal_text_content = ""
        xunfei_config = get_hf_xunfei_config()
        xunfei_appid = xunfei_config.get('appid')
        xunfei_apikey = xunfei_config.get('apikey')
        xunfei_apisecret = xunfei_config.get('apisecret')

        # Debug logging of what arrived and whether credentials are set.
        logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")
        logger.info(f"讯飞配置状态 - appid: {bool(xunfei_appid)}, apikey: {bool(xunfei_apikey)}, apisecret: {bool(xunfei_apisecret)}")

        # Audio input (handled independently of the image).
        if audio_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
            logger.info("开始处理音频输入...")
            try:
                import tempfile
                import soundfile as sf
                import os

                audio_sample_rate, audio_data = audio_f
                logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")

                # Write to a temp .wav so the speech API can read a file path.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
                    sf.write(temp_audio.name, audio_data, audio_sample_rate)
                    temp_audio_path = temp_audio.name
                logger.info(f"音频临时文件已保存: {temp_audio_path}")

                audio_text = audio_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_audio_path)
                logger.info(f"音频识别结果: {audio_text}")
                if audio_text:
                    multimodal_text_content += f"音频内容: {audio_text}"

                os.unlink(temp_audio_path)
                logger.info("音频处理完成")
            except Exception as e:
                logger.error(f"音频处理错误: {str(e)}")
        elif audio_f is not None:
            logger.warning("音频文件存在但讯飞配置不完整,跳过音频处理")

        # Image input (handled independently of the audio).
        if image_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
            logger.info("开始处理图像输入...")
            try:
                import tempfile
                from PIL import Image
                import os

                logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")

                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
                    if len(image_f.shape) == 3:  # RGB image
                        pil_image = Image.fromarray(image_f.astype('uint8'), 'RGB')
                    else:  # grayscale image
                        pil_image = Image.fromarray(image_f.astype('uint8'), 'L')

                    pil_image.save(temp_image.name, 'JPEG')
                    temp_image_path = temp_image.name
                logger.info(f"图像临时文件已保存: {temp_image_path}")

                image_text = image_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_image_path)
                logger.info(f"图像识别结果: {image_text}")
                if image_text:
                    if multimodal_text_content:  # separator when audio text already present
                        multimodal_text_content += "\n"
                    multimodal_text_content += f"图像内容: {image_text}"

                os.unlink(temp_image_path)
                logger.info("图像处理完成")
            except Exception as e:
                logger.error(f"图像处理错误: {str(e)}")
        elif image_f is not None:
            logger.warning("图像文件存在但讯飞配置不完整,跳过图像处理")

        # Decide the final user input: fall back to transcription when the
        # user typed nothing; append it as extra context when they did.
        final_user_content = user_msg_content.strip() if user_msg_content else ""
        if not final_user_content and multimodal_text_content:
            final_user_content = multimodal_text_content
            logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
        elif final_user_content and multimodal_text_content:
            # User typed text; transcriptions become supplementary context.
            final_user_content = f"{final_user_content}\n{multimodal_text_content}"
            logger.info(f"用户有文本输入,多模态内容作为补充")

        # Default placeholder when there was no input of any kind.
        if not final_user_content:
            final_user_content = "[无输入内容]"
            logger.warning("用户没有提供任何输入内容(文本、音频或图像)")

        logger.info(f"最终用户输入内容: {final_user_content}")

        # 1. Append the user turn to the chat history and show it immediately.
        if not ch_history: ch_history = []
        ch_history.append({"role": "user", "content": final_user_content})
        yield ch_history, []

        # 2. Stream the assistant reply, re-shaping the dict history into the
        #    (user, assistant) pair format `respond` expects.
        formatted_hist_for_respond = []
        temp_user_msg_for_hist = None
        for item_hist in ch_history[:-1]:
            if item_hist["role"] == "user":
                temp_user_msg_for_hist = item_hist["content"]
            elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
                formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
                temp_user_msg_for_hist = None
            elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
                formatted_hist_for_respond.append(("", item_hist["content"]))

        ch_history.append({"role": "assistant", "content": ""})

        full_bot_response = ""
        # Chat with the final (possibly transcription-derived) user content.
        for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
            full_bot_response = bot_response_token
            ch_history[-1]["content"] = full_bot_response
            yield ch_history, []

        # 3. Build the ToDo list from the same final user content.
        text_for_todo = final_user_content

        logger.info(f"用于ToDo生成的内容: {text_for_todo}")
        current_todos_list = []

        filtered_result = filter_message_with_llm(text_for_todo)

        # NOTE(review): filter_message_with_llm appears to return lists in
        # this file, so the two dict branches below may be defensive only.
        if isinstance(filtered_result, dict) and "error" in filtered_result:
            current_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
        elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
            current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
        elif isinstance(filtered_result, list):
            category = None

            # Empty filter result: try todo generation directly, then bail out.
            if not filtered_result:
                if text_for_todo:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
                yield ch_history, current_todos_list
                return

            # Take the classification from the first dict entry, if any.
            valid_item = None
            for item in filtered_result:
                if isinstance(item, dict):
                    valid_item = item
                    if "分类" in item:
                        category = item["分类"]
                    break

            if valid_item is None:
                if text_for_todo:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
                yield ch_history, current_todos_list
                return

            if category == "其他":
                current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
            else:
                if text_for_todo:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
        else:
            if text_for_todo:
                msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)

        yield ch_history, current_todos_list

    submit_btn.click(
        handle_submit,
        [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
        [chatbot, todolist_df]
    )
    msg.submit(
        handle_submit,
        [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
        [chatbot, todolist_df]
    )

    def clear_all():
        """Clear the chat history, the todo table, and the input box."""
        return None, None, ""
    clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)

    # Multimodal tab (original standalone audio/image inspector).
    with gr.Tab("Audio/Image Processing (Original)"):
        gr.Markdown("## 处理音频和图片")
        audio_processor = gr.Audio(label="上传音频", type="numpy")
        image_processor = gr.Image(label="上传图片", type="numpy")
        process_btn = gr.Button("处理", variant="primary")
        audio_output = gr.Textbox(label="音频信息")
        image_output = gr.Textbox(label="图片信息")

        process_btn.click(
            process,
            inputs=[audio_processor, image_processor],
            outputs=[audio_output, image_output]
        )

if __name__ == "__main__":
    app.launch(debug=False)
audio_127.0.0.1.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4cca96c289e5acdfd9d8e926bb40674e170374878d57e4d3c3f5aca3039bec8
|
| 3 |
+
size 1830956
|
image_127.0.0.1.jpg
ADDED
|
requirements.txt
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
gradio
|
| 2 |
-
requests
|
| 3 |
-
pathlib
|
| 4 |
-
python-dateutil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
requests
|
| 3 |
+
pathlib
|
| 4 |
+
python-dateutil
|
| 5 |
+
Pillow
|
| 6 |
+
numpy
|
| 7 |
+
soundfile
|
| 8 |
+
azure-ai-inference
|
se_app.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from huggingface_hub import InferenceClient
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scipy.io.wavfile import write as write_wav
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from tools import audio_to_str, image_to_str # 导入tools.py中的方法
|
| 8 |
+
|
| 9 |
+
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
| 10 |
+
|
| 11 |
+
# 指定保存文件的相对路径
|
| 12 |
+
SAVE_DIR = 'download' # 相对路径
|
| 13 |
+
os.makedirs(SAVE_DIR, exist_ok=True) # 确保目录存在
|
| 14 |
+
|
| 15 |
+
def get_client_ip(request: gr.Request, debug_mode=False):
    """Resolve the caller's IP address, honoring reverse-proxy headers.

    Prefers the first entry of ``x-forwarded-for`` when present, otherwise
    falls back to the direct peer address. Returns "unknown" without a request.
    """
    if not request:
        return "unknown"

    forwarded = request.headers.get("x-forwarded-for", "")
    client_ip = forwarded.split(",")[0] if forwarded else request.client.host

    if debug_mode:
        print(f"Debug: Client IP detected as {client_ip}")
    return client_ip
| 28 |
+
|
| 29 |
+
def save_audio(audio, filename):
    """Persist a Gradio (sample_rate, samples) tuple as a .wav file."""
    rate, samples = audio
    write_wav(filename, rate, samples)
|
| 34 |
+
def save_image(image, filename):
    """Persist a numpy image array as an image file (format from extension)."""
    Image.fromarray(image.astype('uint8')).save(filename)
|
| 39 |
+
def process(audio, image, text, request: gr.Request):
    """Demo handler: describe, persist, and transcribe audio/image/text inputs.

    Saves uploads under SAVE_DIR keyed by client IP, runs speech/OCR
    transcription via tools.py, and returns the five display strings.
    """
    client_ip = get_client_ip(request, True)
    print(f"Processing request from IP: {client_ip}")

    audio_info = "未收到音频"
    image_info = "未收到图片"
    text_info = "未收到文本"
    audio_filename = None
    image_filename = None
    audio_text = ""
    image_text = ""

    if audio is not None:
        sample_rate, audio_data = audio
        audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
        # Persist the audio as a .wav file named after the client IP.
        # NOTE(review): x-forwarded-for is client-controlled; a hostile value
        # could influence the filename — sanitize before joining paths.
        audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
        save_audio(audio, audio_filename)
        print(f"Audio saved as {audio_filename}")
        # Transcribe via tools.audio_to_str.
        # SECURITY NOTE(review): hardcoded Xunfei API credentials — move to
        # environment variables / HF Space secrets and rotate these keys.
        audio_text = audio_to_str("33c1b63d", "40bf7cd82e31ace30a9cfb76309a43a3", "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4", audio_filename)
        if audio_text:
            print(f"Audio text: {audio_text}")
        else:
            print("Audio processing failed")

    if image is not None:
        image_info = f"图片尺寸: {image.shape}"
        # Persist the image as a .jpg file named after the client IP.
        image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
        save_image(image, image_filename)
        print(f"Image saved as {image_filename}")
        # OCR via tools.image_to_str.
        # SECURITY NOTE(review): hardcoded Azure endpoint key — move to
        # environment variables / HF Space secrets and rotate this key.
        image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
        if image_text:
            print(f"Image text: {image_text}")
        else:
            print("Image processing failed")

    if text:
        text_info = f"接收到文本: {text}"

    return audio_info, image_info, text_info, audio_text, image_text
| 84 |
+
# 创建自定义的聊天界面
|
| 85 |
+
with gr.Blocks() as app:
|
| 86 |
+
gr.Markdown("# ToDoAgent Multi-Modal Interface")
|
| 87 |
+
|
| 88 |
+
# 创建两个标签页
|
| 89 |
+
with gr.Tab("Chat"):
|
| 90 |
+
# 修复Chatbot类型警告
|
| 91 |
+
chatbot = gr.Chatbot(height=500, type="messages")
|
| 92 |
+
|
| 93 |
+
msg = gr.Textbox(label="输入消息", placeholder="输入您的问题...")
|
| 94 |
+
|
| 95 |
+
# 上传区域
|
| 96 |
+
with gr.Row():
|
| 97 |
+
audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
|
| 98 |
+
image_input = gr.Image(label="上传图片", type="numpy")
|
| 99 |
+
|
| 100 |
+
# 设置区域
|
| 101 |
+
with gr.Accordion("高级设置", open=False):
|
| 102 |
+
system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
|
| 103 |
+
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="最大生成长度")
|
| 104 |
+
temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="温度")
|
| 105 |
+
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
|
| 106 |
+
|
| 107 |
+
# 提交按钮
|
| 108 |
+
submit_btn = gr.Button("发送", variant="primary")
|
| 109 |
+
|
| 110 |
+
# 清除按钮
|
| 111 |
+
clear = gr.Button("清除聊天")
|
| 112 |
+
|
| 113 |
+
# 事件处理
|
| 114 |
+
def user(user_message, chat_history):
|
| 115 |
+
return "", chat_history + [{"role": "user", "content": user_message}]
|
| 116 |
+
#新增多模态处理--1
|
| 117 |
+
def respond(message, chat_history, system_message, max_tokens, temperature, top_p, audio=None, image=None, text=None, request=None):
|
| 118 |
+
"""生成响应的函数"""
|
| 119 |
+
# 处理多模态输入
|
| 120 |
+
multimodal_content = ""
|
| 121 |
+
if audio is not None:
|
| 122 |
+
try:
|
| 123 |
+
audio_filename = os.path.join(SAVE_DIR, "temp_audio.wav")
|
| 124 |
+
save_audio(audio, audio_filename)
|
| 125 |
+
audio_text = audio_to_str("33c1b63d", "40bf7cd82e31ace30a9cfb76309a43a3", "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4", audio_filename)
|
| 126 |
+
if audio_text:
|
| 127 |
+
multimodal_content += f"音频内容: {audio_text}\n"
|
| 128 |
+
except Exception as e:
|
| 129 |
+
print(f"Audio processing error: {e}")
|
| 130 |
+
|
| 131 |
+
if image is not None:
|
| 132 |
+
try:
|
| 133 |
+
image_filename = os.path.join(SAVE_DIR, "temp_image.jpg")
|
| 134 |
+
save_image(image, image_filename)
|
| 135 |
+
image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
|
| 136 |
+
if image_text:
|
| 137 |
+
multimodal_content += f"图片内容: {image_text}\n"
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(f"Image processing error: {e}")
|
| 140 |
+
|
| 141 |
+
# 组合最终消息
|
| 142 |
+
final_message = message
|
| 143 |
+
if multimodal_content:
|
| 144 |
+
final_message = f"{message}\n\n{multimodal_content}"
|
| 145 |
+
|
| 146 |
+
# 构建消息历史
|
| 147 |
+
messages = [{"role": "system", "content": system_message}]
|
| 148 |
+
for chat in chat_history:
|
| 149 |
+
if isinstance(chat, dict) and "role" in chat and "content" in chat:
|
| 150 |
+
messages.append(chat)
|
| 151 |
+
|
| 152 |
+
messages.append({"role": "user", "content": final_message})
|
| 153 |
+
|
| 154 |
+
# 调用HuggingFace API
|
| 155 |
+
try:
|
| 156 |
+
response = client.chat_completion(
|
| 157 |
+
messages,
|
| 158 |
+
max_tokens=max_tokens,
|
| 159 |
+
stream=True,
|
| 160 |
+
temperature=temperature,
|
| 161 |
+
top_p=top_p,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
partial_message = ""
|
| 165 |
+
for token in response:
|
| 166 |
+
if token.choices[0].delta.content is not None:
|
| 167 |
+
partial_message += token.choices[0].delta.content
|
| 168 |
+
yield partial_message
|
| 169 |
+
except Exception as e:
|
| 170 |
+
yield f"抱歉,生成响应时出现错误: {str(e)}"
|
| 171 |
+
|
| 172 |
+
def bot(chat_history, system_message, max_tokens, temperature, top_p, audio, image, text):
    """Stream the assistant reply for the latest user turn in *chat_history*.

    Yields successively longer copies of the history with the partial
    assistant message appended, so the Gradio chatbot updates live.
    """
    # Nothing to answer when the history is empty.
    if not chat_history:
        return

    # The last entry must be a well-formed message dict with content.
    latest = chat_history[-1]
    if not (latest and isinstance(latest, dict) and "content" in latest):
        return

    prompt = latest["content"]

    # Stream partial completions; each yield replaces the assistant turn.
    for partial in respond(
        prompt,
        chat_history[:-1],
        system_message,
        max_tokens,
        temperature,
        top_p,
        audio,
        image,
        text,
    ):
        yield chat_history + [{"role": "assistant", "content": partial}]
| 202 |
+
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
| 203 |
+
bot, [chatbot, system_msg, max_tokens, temperature, top_p, audio_input, image_input, msg], chatbot
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
| 207 |
+
bot, [chatbot, system_msg, max_tokens, temperature, top_p, audio_input, image_input, msg], chatbot
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
| 211 |
+
|
| 212 |
+
with gr.Tab("Audio/Image Processing"):
|
| 213 |
+
gr.Markdown("## 处理音频和图片")
|
| 214 |
+
audio_processor = gr.Audio(label="上传音频", type="numpy")
|
| 215 |
+
image_processor = gr.Image(label="上传图片", type="numpy")
|
| 216 |
+
text_input = gr.Textbox(label="输入文本")
|
| 217 |
+
process_btn = gr.Button("处理", variant="primary")
|
| 218 |
+
audio_output = gr.Textbox(label="音频信息")
|
| 219 |
+
image_output = gr.Textbox(label="图片信息")
|
| 220 |
+
text_output = gr.Textbox(label="文本信息")
|
| 221 |
+
audio_text_output = gr.Textbox(label="音频转文字结果")
|
| 222 |
+
image_text_output = gr.Textbox(label="图片转文字结果")
|
| 223 |
+
|
| 224 |
+
# 修改后的处理函数调用
|
| 225 |
+
process_btn.click(
|
| 226 |
+
process,
|
| 227 |
+
inputs=[audio_processor, image_processor, text_input],
|
| 228 |
+
outputs=[audio_output, image_output, text_output, audio_text_output, image_text_output]
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# Entry point: launch the Gradio app when this file is run as a script.
if __name__ == "__main__":
    app.launch()
|
temp_audio.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8a873051a6c784789c314ab829772eac2446271337f54d58db48e921e81ab71e
|
| 3 |
+
size 710700
|
todogen_LLM_config.yaml
CHANGED
|
@@ -38,4 +38,14 @@ HF_CONFIG_PATH:
|
|
| 38 |
openai_filter:
|
| 39 |
base_url_filter: https://aihubmix.com/v1
|
| 40 |
api_key_filter: sk-BSNyITzJBSSgfFdJ792b66C7789c479cA7Ec1e36FfB343A1
|
| 41 |
-
model_filter: gpt-4o-mini
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
openai_filter:
|
| 39 |
base_url_filter: https://aihubmix.com/v1
|
| 40 |
api_key_filter: sk-BSNyITzJBSSgfFdJ792b66C7789c479cA7Ec1e36FfB343A1
|
| 41 |
+
model_filter: gpt-4o-mini
|
| 42 |
+
|
| 43 |
+
xunfei:
|
| 44 |
+
appid: 33c1b63d
|
| 45 |
+
apikey: 40bf7cd82e31ace30a9cfb76309a43a3
|
| 46 |
+
apisecret: OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4
|
| 47 |
+
|
| 48 |
+
azure_speech:
|
| 49 |
+
key: 45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ
|
| 50 |
+
region: eastus2
|
| 51 |
+
endpoint: https://eastus2.stt.speech.microsoft.com
|
tools.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import datetime
|
| 5 |
+
import re
|
| 6 |
+
import time
|
| 7 |
+
import traceback
|
| 8 |
+
import math
|
| 9 |
+
from urllib.parse import urlparse
|
| 10 |
+
from urllib3 import encode_multipart_formdata
|
| 11 |
+
from wsgiref.handlers import format_date_time
|
| 12 |
+
from time import mktime
|
| 13 |
+
import hashlib
|
| 14 |
+
import base64
|
| 15 |
+
import hmac
|
| 16 |
+
from urllib.parse import urlencode
|
| 17 |
+
import json
|
| 18 |
+
import requests
|
| 19 |
+
import azure.cognitiveservices.speech as speechsdk
|
| 20 |
+
|
| 21 |
+
# Constant definitions for the xfyun OST (speech transcription) service.
LFASR_HOST = "http://upload-ost-api.xfyun.cn/file"  # file-upload host
API_INIT = "/mpupload/init"  # multipart-upload init endpoint
API_UPLOAD = "/upload"  # simple (single-request) upload endpoint
API_CUT = "/mpupload/upload"  # chunk-upload endpoint
API_CUT_COMPLETE = "/mpupload/complete"  # chunk-upload completion endpoint
API_CUT_CANCEL = "/mpupload/cancel"  # chunk-upload cancel endpoint
FILE_PIECE_SIZE = 5242880  # chunk size: 5 MB
PRO_CREATE_URI = "/v2/ost/pro_create"  # transcription-task creation URI
QUERY_URI = "/v2/ost/query"  # transcription-task query URI
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# 文件上传类
|
| 34 |
+
class FileUploader:
    """Uploads an audio file to the xfyun OST file service.

    Files under 30 MB go through the simple upload endpoint; larger files
    are split into 5 MB chunks and pushed through the multipart API.
    """

    def __init__(self, app_id, api_key, api_secret, upload_file_path):
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.upload_file_path = upload_file_path

    def get_request_id(self):
        """Build a request id from the current time (minute resolution)."""
        return time.strftime("%Y%m%d%H%M")

    def hashlib_256(self, data):
        """Return 'SHA-256=<base64(sha256(data))>' for the given string."""
        raw = hashlib.sha256(data.encode(encoding="utf-8")).digest()
        return "SHA-256=" + base64.b64encode(raw).decode(encoding="utf-8")

    def assemble_auth_header(self, request_url, file_data_type, method="", body=""):
        """Assemble the HMAC-SHA256 auth headers for an upload request.

        NOTE(review): the digest is computed over the empty string (not
        *body*) and reads 'SHA256=SHA-256=<b64>' because hashlib_256 already
        prefixes its result — preserved as-is since this is what the demo
        sends; confirm against the xfyun signing docs before changing.
        """
        parts = urlparse(request_url)
        host = parts.hostname
        path = parts.path
        date = format_date_time(mktime(datetime.datetime.now().timetuple()))
        digest = "SHA256=" + self.hashlib_256("")
        signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1\ndigest: {}".format(
            host, date, method, path, digest
        )
        raw_signature = hmac.new(
            self.api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        encoded_signature = base64.b64encode(raw_signature).decode(encoding="utf-8")
        authorization = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
            self.api_key,
            "hmac-sha256",
            "host date request-line digest",
            encoded_signature,
        )
        return {
            "host": host,
            "date": date,
            "authorization": authorization,
            "digest": digest,
            "content-type": file_data_type,
        }

    def call_api(self, url, file_data, file_data_type):
        """POST *file_data* to *url*; return the parsed JSON or None."""
        headers = self.assemble_auth_header(
            url, file_data_type, method="POST", body=file_data
        )
        try:
            resp = requests.post(url, headers=headers, data=file_data, timeout=8)
            print("上传状态:", resp.status_code, resp.text)
            return resp.json()
        except Exception as e:
            print("上传失败!Exception :%s" % e)
            return None

    def upload_cut_complete(self, upload_id):
        """Finish a multipart upload and return the resulting file URL."""
        payload = json.dumps(
            {
                "app_id": self.app_id,
                "request_id": self.get_request_id(),
                "upload_id": upload_id,
            }
        )
        response = self.call_api(
            LFASR_HOST + API_CUT_COMPLETE, payload, "application/json"
        )
        if response and "data" in response and "url" in response["data"]:
            print("任务上传结束")
            return response["data"]["url"]
        print("分片上传完成失败", response)
        return None

    def upload_file(self):
        """Upload the file, choosing simple vs. multipart by size (30 MB)."""
        if os.path.getsize(self.upload_file_path) < 31457280:  # 30MB
            print("-----不使用分块上传-----")
            return self.simple_upload()
        print("-----使用分块上传-----")
        return self.multipart_upload()

    def simple_upload(self):
        """Upload the whole file in one multipart/form-data request."""
        try:
            with open(self.upload_file_path, mode="rb") as f:
                form = {
                    "data": (self.upload_file_path, f.read()),
                    "app_id": self.app_id,
                    "request_id": self.get_request_id(),
                }
            body, content_type = encode_multipart_formdata(form)
            response = self.call_api(LFASR_HOST + API_UPLOAD, body, content_type)
            if response and "data" in response and "url" in response["data"]:
                return response["data"]["url"]
            print("简单上传失败", response)
            return None
        except FileNotFoundError:
            print("文件未找到:", self.upload_file_path)
            return None

    def multipart_upload(self):
        """Run init -> chunk uploads -> complete; return the file URL."""
        upload_id = self.prepare_upload()
        if not upload_id:
            return None
        if not self.do_upload(upload_id):
            return None
        file_url = self.upload_cut_complete(upload_id)
        print("分片上传地址:", file_url)
        return file_url

    def prepare_upload(self):
        """Initialise a multipart upload and return its upload_id."""
        payload = json.dumps(
            {"app_id": self.app_id, "request_id": self.get_request_id()}
        )
        response = self.call_api(LFASR_HOST + API_INIT, payload, "application/json")
        if response and "data" in response and "upload_id" in response["data"]:
            return response["data"]["upload_id"]
        print("预处理失败", response)
        return None

    def do_upload(self, upload_id):
        """Upload every 5 MB chunk in order; retry each chunk up to 3 times."""
        total_size = os.path.getsize(self.upload_file_path)
        chunk_size = FILE_PIECE_SIZE
        chunk_count = math.ceil(total_size / chunk_size)
        request_id = self.get_request_id()

        print(
            "文件:",
            self.upload_file_path,
            " 文件大小:",
            total_size,
            " 分块大小:",
            chunk_size,
            " 分块数:",
            chunk_count,
        )

        with open(self.upload_file_path, mode="rb") as stream:
            for slice_id in range(1, chunk_count + 1):
                # Last chunk may be shorter than FILE_PIECE_SIZE.
                read_size = min(
                    chunk_size, total_size - (slice_id - 1) * chunk_size
                )
                form = {
                    "data": (self.upload_file_path, stream.read(read_size)),
                    "app_id": self.app_id,
                    "request_id": request_id,
                    "upload_id": upload_id,
                    "slice_id": slice_id,
                }
                body, content_type = encode_multipart_formdata(form)
                url = LFASR_HOST + API_CUT

                resp = self.call_api(url, body, content_type)
                retries = 0
                while not resp and retries < 3:
                    print("上传重试")
                    resp = self.call_api(url, body, content_type)
                    retries += 1
                    time.sleep(1)
                if not resp:
                    print("分片上传失败")
                    return False
        return True
| 227 |
+
|
| 228 |
+
class ResultExtractor:
    """Creates xfyun OST transcription tasks, polls them, and extracts text.

    Improvement over the original: `extract_text` duplicated the entire
    st -> rt -> ws -> cw traversal in its 'lattice' and top-level 'st'
    branches; that traversal now lives in one private helper,
    `_collect_words`, with identical behavior.
    """

    def __init__(self, appid, apikey, apisecret):
        # POST request endpoints.
        self.Host = "ost-api.xfyun.cn"
        self.RequestUriCreate = PRO_CREATE_URI
        self.RequestUriQuery = QUERY_URI
        # Plain http only when the host is a bare IP (starts with a digit).
        if re.match(r"^\d", self.Host):
            self.urlCreate = "http://" + self.Host + self.RequestUriCreate
            self.urlQuery = "http://" + self.Host + self.RequestUriQuery
        else:
            self.urlCreate = "https://" + self.Host + self.RequestUriCreate
            self.urlQuery = "https://" + self.Host + self.RequestUriQuery
        self.HttpMethod = "POST"
        self.APPID = appid
        self.Algorithm = "hmac-sha256"
        self.HttpProto = "HTTP/1.1"
        self.UserName = apikey
        self.Secret = apisecret

        # The signed Date header is fixed at construction time.
        cur_time_utc = datetime.datetime.now(datetime.timezone.utc)
        self.Date = self.httpdate(cur_time_utc)

        # Transcription engine parameters.
        self.BusinessArgsCreate = {
            "language": "zh_cn",
            "accent": "mandarin",
            "domain": "pro_ost_ed",
        }

    def img_read(self, path):
        """Read a file and return its raw bytes."""
        with open(path, "rb") as fo:
            return fo.read()

    def hashlib_256(self, res):
        """Return 'SHA-256=<base64(sha256(res))>' for the given string."""
        m = hashlib.sha256(bytes(res.encode(encoding="utf-8"))).digest()
        result = "SHA-256=" + base64.b64encode(m).decode(encoding="utf-8")
        return result

    def httpdate(self, dt):
        """Format *dt* as an RFC 1123 HTTP date string (GMT suffix)."""
        weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"][dt.weekday()]
        month = [
            "Jan", "Feb", "Mar", "Apr", "May", "Jun",
            "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
        ][dt.month - 1]
        return "%s, %02d %s %04d %02d:%02d:%02d GMT" % (
            weekday,
            dt.day,
            month,
            dt.year,
            dt.hour,
            dt.minute,
            dt.second,
        )

    def generateSignature(self, digest, uri):
        """HMAC-SHA256 sign host/date/request-line/digest; return base64."""
        signature_str = "host: " + self.Host + "\n"
        signature_str += "date: " + self.Date + "\n"
        signature_str += self.HttpMethod + " " + uri + " " + self.HttpProto + "\n"
        signature_str += "digest: " + digest
        signature = hmac.new(
            bytes(self.Secret.encode("utf-8")),
            bytes(signature_str.encode("utf-8")),
            digestmod=hashlib.sha256,
        ).digest()
        return base64.b64encode(signature).decode(encoding="utf-8")

    def init_header(self, data, uri):
        """Build the signed request headers for a JSON body *data*."""
        digest = self.hashlib_256(data)
        sign = self.generateSignature(digest, uri)
        auth_header = (
            'api_key="%s",algorithm="%s", '
            'headers="host date request-line digest", '
            'signature="%s"' % (self.UserName, self.Algorithm, sign)
        )
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Method": "POST",
            "Host": self.Host,
            "Date": self.Date,
            "Digest": digest,
            "Authorization": auth_header,
        }

    def get_create_body(self, fileurl):
        """JSON body for task creation, pointing at the uploaded audio URL."""
        post_data = {
            "common": {"app_id": self.APPID},
            "business": self.BusinessArgsCreate,
            "data": {"audio_src": "http", "audio_url": fileurl, "encoding": "raw"},
        }
        return json.dumps(post_data)

    def get_query_body(self, task_id):
        """JSON body for querying an existing task."""
        post_data = {
            "common": {"app_id": self.APPID},
            "business": {
                "task_id": task_id,
            },
        }
        return json.dumps(post_data)

    def call(self, url, body, headers):
        """POST *body*; return parsed JSON on 200, raw content/text otherwise.

        Returns None when the request itself raises.
        """
        try:
            response = requests.post(url, data=body, headers=headers, timeout=8)
            if response.status_code != 200:
                return response.content
            try:
                return json.loads(response.text)
            except json.JSONDecodeError:
                return response.text
        except Exception as e:
            print("Exception :%s" % e)
            return None

    def task_create(self, fileurl):
        """Create a transcription task for *fileurl*."""
        body = self.get_create_body(fileurl)
        headers_create = self.init_header(body, self.RequestUriCreate)
        return self.call(self.urlCreate, body, headers_create)

    def task_query(self, task_id):
        """Query the status/result of task *task_id*."""
        query_body = self.get_query_body(task_id)
        headers_query = self.init_header(query_body, self.RequestUriQuery)
        return self.call(self.urlQuery, query_body, headers_query)

    def _collect_words(self, st_list):
        """Flatten st -> rt -> ws -> cw word entries into a list of strings.

        Shared by the 'lattice' and top-level 'st' result shapes. Every
        level is type-checked so malformed entries are skipped silently,
        matching the original's defensive behavior.
        """
        parts = []
        for st in st_list:
            if isinstance(st, str):
                # Already-flattened text fragment.
                parts.append(st)
            elif isinstance(st, dict):
                rt = st.get("rt", [])
                if not isinstance(rt, list):
                    continue
                for item in rt:
                    if not isinstance(item, dict):
                        continue
                    ws_list = item.get("ws", [])
                    if not isinstance(ws_list, list):
                        continue
                    for ws in ws_list:
                        if not isinstance(ws, dict):
                            continue
                        cw_list = ws.get("cw", [])
                        if not isinstance(cw_list, list):
                            continue
                        for cw in cw_list:
                            if isinstance(cw, dict):
                                w = cw.get("w", "")
                                if w:
                                    parts.append(w)
        return parts

    def extract_text(self, result):
        """Extract plain text from an OST API result.

        Handles JSON strings, error payloads, detailed 'lattice' results,
        bare 'st' results and unknown shapes; never raises on bad input.
        """
        print(f"\n[DEBUG] extract_text 输入类型: {type(result)}")

        # Strings may be JSON-encoded results — try to decode first.
        if isinstance(result, str):
            print(f"[DEBUG] 字符串内容 (前200字符): {result[:200]}")
            try:
                result = json.loads(result)
                print("[DEBUG] 成功解析字符串为JSON对象")
            except json.JSONDecodeError:
                print("[DEBUG] 无法解析为JSON,返回原始字符串")
                return result

        if isinstance(result, dict):
            print("[DEBUG] 处理字典类型结果")

            # 1. Error payloads: non-zero code means failure.
            if "code" in result and result["code"] != 0:
                error_msg = result.get("message", "未知错误")
                print(
                    f"[ERROR] API返回错误: code={result['code']}, message={error_msg}"
                )
                return f"错误: {error_msg}"

            # 2. Direct text result.
            if "result" in result and isinstance(result["result"], str):
                print("[DEBUG] 找到直接结果字段")
                return result["result"]

            # 3. Detailed 'lattice' result.
            if "lattice" in result and isinstance(result["lattice"], list):
                print("[DEBUG] 解析lattice结构")
                text_parts = []
                for lattice in result["lattice"]:
                    if not isinstance(lattice, dict):
                        continue
                    json_1best = lattice.get("json_1best", {})
                    if not json_1best or not isinstance(json_1best, dict):
                        continue
                    # 'st' may be a single dict or a list — normalize.
                    st_content = json_1best.get("st")
                    if isinstance(st_content, dict):
                        st_list = [st_content]
                    elif isinstance(st_content, list):
                        st_list = st_content
                    else:
                        st_list = []
                    text_parts.extend(self._collect_words(st_list))
                return "".join(text_parts)

            # 4. Simplified result with a top-level 'st' list.
            if "st" in result and isinstance(result["st"], list):
                print("[DEBUG] 解析st结构")
                return "".join(self._collect_words(result["st"]))

            # 5. Unknown dict shape: return pretty-printed JSON.
            print("[WARNING] 无法识别的结果结构")
            return json.dumps(result, indent=2, ensure_ascii=False)

        # 6. Anything else: stringify.
        print(f"[WARNING] 非字典类型结果: {type(result)}")
        return str(result)
|
| 485 |
+
|
| 486 |
+
def audio_to_str(appid, apikey, apisecret, file_path):
    """Transcribe an audio file via the iFlytek (xfyun) OST API.

    Uploads the file, creates a transcription task, then polls the task
    until it completes, fails, or the retry budget is exhausted.

    Args:
        appid: xfyun open-platform appid.
        apikey: xfyun open-platform apikey.
        apisecret: xfyun open-platform apisecret.
        file_path: path to the audio file to transcribe.

    Returns:
        The transcription text, or None on any error or timeout.
    """
    # Bail out early if the file does not exist.
    if not os.path.exists(file_path):
        print(f"错误:文件 {file_path} 不存在")
        return None

    try:
        # 1. Upload the file (simple or chunked, decided by FileUploader).
        file_uploader = FileUploader(
            app_id=appid,
            api_key=apikey,
            api_secret=apisecret,
            upload_file_path=file_path,
        )
        fileurl = file_uploader.upload_file()
        if not fileurl:
            print("文件上传失败")
            return None
        print("文件上传成功,fileurl:", fileurl)

        # 2. Create the transcription task, then poll for the result.
        result_extractor = ResultExtractor(appid, apikey, apisecret)
        print("\n------ 创建任务 -------")
        create_response = result_extractor.task_create(fileurl)

        # Debug: dump the raw task-creation response.
        print(
            f"[DEBUG] 创建任务响应: {json.dumps(create_response, indent=2, ensure_ascii=False)}"
        )

        if not isinstance(create_response, dict) or "data" not in create_response:
            print("创建任务失败:", create_response)
            return None

        task_id = create_response["data"]["task_id"]
        print(f"任务ID: {task_id}")

        # Poll the task status, at most max_attempts times, 5 s apart.
        print("\n------ 查询任务 -------")
        print("任务转写中······")
        max_attempts = 30
        attempt = 0

        while attempt < max_attempts:
            result = result_extractor.task_query(task_id)

            # Debug: dump each polling response.
            print(f"\n[QUERY {attempt + 1}] 响应类型: {type(result)}")
            if isinstance(result, dict):
                print(
                    f"[QUERY {attempt + 1}] 响应内容: {json.dumps(result, indent=2, ensure_ascii=False)}"
                )
            else:
                print(
                    f"[QUERY {attempt + 1}] 响应内容 (前200字符): {str(result)[:200]}"
                )

            # Non-dict responses cannot be interpreted — give up.
            if not isinstance(result, dict):
                print(f"无效响应类型: {type(result)}")
                return None

            # Non-zero code means the API reported an error.
            if "code" in result and result["code"] != 0:
                error_msg = result.get("message", "未知错误")
                print(f"API错误: code={result['code']}, message={error_msg}")
                return None

            # Task status lives under data.task_status.
            task_data = result.get("data", {})
            task_status = task_data.get("task_status")

            if not task_status:
                print("响应中缺少任务状态字段")
                print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
                return None

            # Dispatch on the reported status.
            if task_status in ["3", "4"]:  # completed / callback completed
                print("转写完成···")

                # Extract the transcription text from the result payload.
                result_content = task_data.get("result")
                if result_content is not None:
                    try:
                        result_text = result_extractor.extract_text(result_content)
                        print("\n转写结果:\n", result_text)
                        return result_text
                    except Exception as e:
                        print(f"\n提取文本时出错: {str(e)}")
                        print(f"错误详情:\n{traceback.format_exc()}")
                        print(
                            "原始结果内容:",
                            json.dumps(result_content, indent=2, ensure_ascii=False),
                        )
                        return None
                else:
                    print("\n响应中缺少结果字段")
                    print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
                    return None

            elif task_status in ["1", "2"]:  # pending / processing
                print(
                    f"任务状态:{task_status},等待中... (尝试 {attempt + 1}/{max_attempts})"
                )
                time.sleep(5)
                attempt += 1
            else:
                print(f"未知任务状态:{task_status}")
                print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
                return None
        else:
            # while/else: loop exhausted without returning → polling timed out.
            print(f"超过最大查询次数({max_attempts}),任务可能仍在处理中")
            return None

    except Exception as e:
        print(f"发生异常: {str(e)}")
        print(f"错误详情:\n{traceback.format_exc()}")
        return None
|
| 618 |
+
|
| 619 |
+
"""
|
| 620 |
+
1、通用文字识别,图像数据base64编码后大小不得超过10M
|
| 621 |
+
2、appid、apiSecret、apiKey请到讯飞开放平台控制台获取并填写到此demo中
|
| 622 |
+
3、支持中英文,支持手写和印刷文字。
|
| 623 |
+
4、在倾斜文字上效果有提升,同时支持部分生僻字的识别
|
| 624 |
+
"""
|
| 625 |
+
|
| 626 |
+
# 图像识别接口地址
|
| 627 |
+
URL = "https://api.xf-yun.com/v1/private/sf8e6aca1"
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
class AssembleHeaderException(Exception):
    """Raised when a request URL cannot be parsed for auth-header assembly."""

    def __init__(self, msg):
        # Forward to Exception so str(e) and e.args carry the message
        # (the original stored it only on .message, leaving str(e) empty).
        super().__init__(msg)
        self.message = msg
| 634 |
+
|
| 635 |
+
class Url:
    """Simple container for the pieces of a parsed request URL."""

    def __init__(self, host, path, schema):
        # Store the three components verbatim.
        self.host = host
        self.path = path
        self.schema = schema
|
| 642 |
+
|
| 643 |
+
# calculate sha256 and encode to base64
|
| 644 |
+
def sha256base64(data):
    """Return the SHA-256 digest of *data* (bytes), base64-encoded as str."""
    digest_bytes = hashlib.sha256(data).digest()
    return base64.b64encode(digest_bytes).decode(encoding="utf-8")
|
| 650 |
+
|
| 651 |
+
def parse_url(requset_url):
    """Split a request URL into (host, path, schema) as a Url object.

    Raises:
        AssembleHeaderException: if the URL lacks a '://' scheme separator
            or a path component. (The original let str.index's ValueError
            escape in those cases instead of the module's own exception.)
    """
    try:
        stidx = requset_url.index("://")
        host = requset_url[stidx + 3 :]
        schema = requset_url[: stidx + 3]
        edidx = host.index("/")
    except ValueError as err:
        raise AssembleHeaderException("invalid request url:" + requset_url) from err
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    path = host[edidx:]
    host = host[:edidx]
    return Url(host, path, schema)
|
| 663 |
+
|
| 664 |
+
# build websocket auth request url
|
| 665 |
+
def assemble_ws_auth_url(requset_url, method="POST", api_key="", api_secret=""):
    """Build an authenticated request URL for the iFlytek open platform.

    Appends ``host``, ``date`` and ``authorization`` query parameters,
    where ``authorization`` is a base64-encoded header string carrying an
    HMAC-SHA256 signature over the host, RFC-1123 date and request line.

    Parameters:
        requset_url (str): Base endpoint URL (spelling kept for callers).
        method (str): HTTP method used in the signed request line.
        api_key (str): Platform API key embedded in the auth header.
        api_secret (str): Platform API secret used as the HMAC key.

    Returns:
        str: ``requset_url`` plus the URL-encoded auth query string.
    """
    parsed = parse_url(requset_url)
    rfc1123_date = format_date_time(mktime(datetime.datetime.now().timetuple()))

    # Exact byte layout of the string-to-sign mandated by the platform.
    signing_input = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
        parsed.host, rfc1123_date, method, parsed.path
    )
    mac = hmac.new(
        api_secret.encode("utf-8"),
        signing_input.encode("utf-8"),
        digestmod=hashlib.sha256,
    )
    signature_b64 = base64.b64encode(mac.digest()).decode("utf-8")

    auth_plain = (
        'api_key="%s", algorithm="%s", headers="%s", signature="%s"'
        % (api_key, "hmac-sha256", "host date request-line", signature_b64)
    )
    auth_b64 = base64.b64encode(auth_plain.encode("utf-8")).decode("utf-8")

    query = {"host": parsed.host, "date": rfc1123_date, "authorization": auth_b64}
    return requset_url + "?" + urlencode(query)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def image_to_str(endpoint=None, key=None, unused_param=None, file_path=None):
    """Recognize text in an image via the Azure Computer Vision Read API.

    Parameters:
        endpoint (str): Azure Computer Vision endpoint URL; falls back to
            the project's demo endpoint when None.
        key (str): Azure Computer Vision API key; falls back to the demo
            key when None.
        unused_param: Ignored; kept only for backward compatibility with
            the old iFlytek-style signature.
        file_path (str): Path of the image file to analyze.

    Returns:
        str: Recognized text joined with spaces ("" when nothing found),
        or None on request failure, recognition failure, timeout or error.
    """
    import time  # local import, as in the original; avoids touching module imports

    # SECURITY NOTE(review): these default credentials are committed to the
    # repository. They should be rotated and sourced from environment
    # variables or a secrets store instead of living in code.
    if endpoint is None:
        endpoint = "https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/"
    if key is None:
        key = "45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ"

    try:
        # Read the raw image bytes.
        with open(file_path, "rb") as f:
            image_data = f.read()

        # The Read API is asynchronous: POST starts the job, then we poll.
        analyze_url = endpoint.rstrip('/') + "/vision/v3.2/read/analyze"
        headers = {
            'Ocp-Apim-Subscription-Key': key,
            'Content-Type': 'application/octet-stream'
        }

        response = requests.post(analyze_url, headers=headers, data=image_data)
        if response.status_code != 202:  # 202 Accepted = job queued
            print(f"分析请求失败: {response.status_code}, {response.text}")
            return None

        operation_url = response.headers["Operation-Location"]

        # Fix: the original looped forever (`while True`) if the service
        # never reached a terminal state; bound the polling instead.
        max_attempts = 60
        for _ in range(max_attempts):
            result_response = requests.get(
                operation_url, headers={'Ocp-Apim-Subscription-Key': key}
            )
            result = result_response.json()
            # Fix: use .get() so a malformed response cannot raise KeyError.
            status = result.get("status")

            if status == "succeeded":
                text_results = []
                analyze_result = result.get("analyzeResult", {})
                for read_result in analyze_result.get("readResults", []):
                    for line in read_result["lines"]:
                        text_results.append(line["text"])
                return " ".join(text_results) if text_results else ""

            if status == "failed":
                print(f"文字识别失败: {result}")
                return None

            # Still running/not started: wait a second before the next poll.
            time.sleep(1)

        print(f"文字识别超时(超过 {max_attempts} 次轮询)")
        return None

    except Exception as e:
        print(f"发生异常: {e}")
        return None
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
if __name__ == "__main__":
    # Demo driver: run the speech-to-text and OCR helpers on sample files.
    # SECURITY NOTE(review): the fallback credentials below are committed to
    # the repository and should be rotated. Environment-variable overrides
    # are supported so real deployments never need the literals.
    import os

    # iFlytek open-platform credentials (env override, same defaults).
    appid = os.environ.get("XFYUN_APPID", "33c1b63d")
    apikey = os.environ.get("XFYUN_APIKEY", "40bf7cd82e31ace30a9cfb76309a43a3")
    apisecret = os.environ.get("XFYUN_APISECRET", "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4")
    audio_path = r"audio_sample_little.wav"  # ensure the file path is correct
    image_path = r"1.png"  # ensure the file path is correct

    # Audio -> text (iFlytek)
    audio_text = audio_to_str(appid, apikey, apisecret, audio_path)
    # Image -> text (Azure Computer Vision)
    image_text = image_to_str(
        endpoint=os.environ.get(
            "AZURE_CV_ENDPOINT",
            "https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/",
        ),
        key=os.environ.get(
            "AZURE_CV_KEY",
            "45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ",
        ),
        unused_param=None,
        file_path=image_path,
    )

    print("-" * 20)

    print("\n音频转文字结果:", audio_text)
    print("\n图片转文字结果:", image_text)
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def azure_speech_to_text(speech_key, speech_region, audio_file_path):
    """Transcribe an audio file to text using the Azure Speech service.

    Parameters:
        speech_key (str): Azure Speech service API key.
        speech_region (str): Azure Speech service region.
        audio_file_path (str): Path of the audio file to transcribe.

    Returns:
        str: The recognized text, or None on no-match, cancellation or error.
    """
    try:
        # Configure Chinese (zh-CN) recognition over a file-based source.
        config = speechsdk.SpeechConfig(subscription=speech_key, region=speech_region)
        config.speech_recognition_language = "zh-CN"
        audio_input = speechsdk.audio.AudioConfig(filename=audio_file_path)
        recognizer = speechsdk.SpeechRecognizer(
            speech_config=config, audio_config=audio_input
        )

        # Single-shot recognition: returns after the first utterance.
        outcome = recognizer.recognize_once()

        if outcome.reason == speechsdk.ResultReason.RecognizedSpeech:
            print(f"Azure Speech识别成功: {outcome.text}")
            return outcome.text
        if outcome.reason == speechsdk.ResultReason.NoMatch:
            print("Azure Speech未识别到语音")
            return None
        if outcome.reason == speechsdk.ResultReason.Canceled:
            details = outcome.cancellation_details
            print(f"Azure Speech识别被取消: {details.reason}")
            if details.reason == speechsdk.CancellationReason.Error:
                print(f"错误详情: {details.error_details}")
            return None
    except Exception as e:
        print(f"Azure Speech识别出错: {str(e)}")
        return None
|